Repository: harubaru/discord-stable-diffusion Branch: main Commit: 5a7e5f6f6963 Files: 72 Total size: 557.1 KB Directory structure: gitextract_gdf_fm52/ ├── .gitignore ├── LICENSE ├── README.md ├── __main__.py ├── models/ │ ├── .keep │ └── v1-inference.yaml ├── requirements.txt ├── run.bat ├── run.sh ├── setup.bat ├── setup.sh ├── src/ │ ├── bot/ │ │ ├── shanghai.py │ │ └── stablecog.py │ ├── core/ │ │ └── logging.py │ ├── scripts/ │ │ └── win10patch.py │ └── stablediffusion/ │ ├── dream.py │ ├── inpaint.py │ ├── ldm/ │ │ ├── __init__.py │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── imagenet.py │ │ │ ├── lsun.py │ │ │ ├── personalized.py │ │ │ └── personalized_style.py │ │ ├── dream/ │ │ │ ├── conditioning.py │ │ │ ├── devices.py │ │ │ ├── generator/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── img2img.py │ │ │ │ ├── inpaint.py │ │ │ │ └── txt2img.py │ │ │ ├── image_util.py │ │ │ ├── pngwriter.py │ │ │ ├── readline.py │ │ │ └── server.py │ │ ├── generate.py │ │ ├── gfpgan/ │ │ │ └── gfpgan_tools.py │ │ ├── lr_scheduler.py │ │ ├── models/ │ │ │ ├── autoencoder.py │ │ │ └── diffusion/ │ │ │ ├── __init__.py │ │ │ ├── classifier.py │ │ │ ├── ddim.py │ │ │ ├── ddpm.py │ │ │ ├── ksampler.py │ │ │ └── plms.py │ │ ├── modules/ │ │ │ ├── attention.py │ │ │ ├── diffusionmodules/ │ │ │ │ ├── __init__.py │ │ │ │ ├── model.py │ │ │ │ ├── openaimodel.py │ │ │ │ └── util.py │ │ │ ├── distributions/ │ │ │ │ ├── __init__.py │ │ │ │ └── distributions.py │ │ │ ├── ema.py │ │ │ ├── embedding_manager.py │ │ │ ├── encoders/ │ │ │ │ ├── __init__.py │ │ │ │ └── modules.py │ │ │ ├── image_degradation/ │ │ │ │ ├── __init__.py │ │ │ │ ├── bsrgan.py │ │ │ │ ├── bsrgan_light.py │ │ │ │ └── utils_image.py │ │ │ ├── losses/ │ │ │ │ ├── __init__.py │ │ │ │ ├── contperceptual.py │ │ │ │ └── vqperceptual.py │ │ │ └── x_transformer.py │ │ ├── simplet2i.py │ │ └── util.py │ ├── text2image_compvis.py │ ├── text2image_diffusers.py │ └── translation.py ├── storage/ │ ├── 
init/ │ │ └── .keep │ └── outputs/ │ └── .keep └── win10fix.bat ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ storage/outputs/*.png storage/init/*.png # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class log.txt # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ waifu-diffusion/ ================================================ FILE: LICENSE ================================================ GNU GENERAL PUBLIC LICENSE Version 2, June 1991 Copyright (C) 1989, 1991 Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Preamble The licenses for most software are designed to take away your freedom to share and change it. By contrast, the GNU General Public License is intended to guarantee your freedom to share and change free software--to make sure the software is free for all its users. This General Public License applies to most of the Free Software Foundation's software and to any other program whose authors commit to using it. (Some other Free Software Foundation software is covered by the GNU Lesser General Public License instead.) You can apply it to your programs, too. When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for this service if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs; and that you know you can do these things. To protect your rights, we need to make restrictions that forbid anyone to deny you these rights or to ask you to surrender the rights. 
These restrictions translate to certain responsibilities for you if you distribute copies of the software, or if you modify it. For example, if you distribute copies of such a program, whether gratis or for a fee, you must give the recipients all the rights that you have. You must make sure that they, too, receive or can get the source code. And you must show them these terms so they know their rights. We protect your rights with two steps: (1) copyright the software, and (2) offer you this license which gives you legal permission to copy, distribute and/or modify the software. Also, for each author's protection and ours, we want to make certain that everyone understands that there is no warranty for this free software. If the software is modified by someone else and passed on, we want its recipients to know that what they have is not the original, so that any problems introduced by others will not reflect on the original authors' reputations. Finally, any free program is threatened constantly by software patents. We wish to avoid the danger that redistributors of a free program will individually obtain patent licenses, in effect making the program proprietary. To prevent this, we have made it clear that any patent must be licensed for everyone's free use or not licensed at all. The precise terms and conditions for copying, distribution and modification follow. GNU GENERAL PUBLIC LICENSE TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 0. This License applies to any program or other work which contains a notice placed by the copyright holder saying it may be distributed under the terms of this General Public License. The "Program", below, refers to any such program or work, and a "work based on the Program" means either the Program or any derivative work under copyright law: that is to say, a work containing the Program or a portion of it, either verbatim or with modifications and/or translated into another language. 
(Hereinafter, translation is included without limitation in the term "modification".) Each licensee is addressed as "you". Activities other than copying, distribution and modification are not covered by this License; they are outside its scope. The act of running the Program is not restricted, and the output from the Program is covered only if its contents constitute a work based on the Program (independent of having been made by running the Program). Whether that is true depends on what the Program does. 1. You may copy and distribute verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice and disclaimer of warranty; keep intact all the notices that refer to this License and to the absence of any warranty; and give any other recipients of the Program a copy of this License along with the Program. You may charge a fee for the physical act of transferring a copy, and you may at your option offer warranty protection in exchange for a fee. 2. You may modify your copy or copies of the Program or any portion of it, thus forming a work based on the Program, and copy and distribute such modifications or work under the terms of Section 1 above, provided that you also meet all of these conditions: a) You must cause the modified files to carry prominent notices stating that you changed the files and the date of any change. b) You must cause any work that you distribute or publish, that in whole or in part contains or is derived from the Program or any part thereof, to be licensed as a whole at no charge to all third parties under the terms of this License. 
c) If the modified program normally reads commands interactively when run, you must cause it, when started running for such interactive use in the most ordinary way, to print or display an announcement including an appropriate copyright notice and a notice that there is no warranty (or else, saying that you provide a warranty) and that users may redistribute the program under these conditions, and telling the user how to view a copy of this License. (Exception: if the Program itself is interactive but does not normally print such an announcement, your work based on the Program is not required to print an announcement.) These requirements apply to the modified work as a whole. If identifiable sections of that work are not derived from the Program, and can be reasonably considered independent and separate works in themselves, then this License, and its terms, do not apply to those sections when you distribute them as separate works. But when you distribute the same sections as part of a whole which is a work based on the Program, the distribution of the whole must be on the terms of this License, whose permissions for other licensees extend to the entire whole, and thus to each and every part regardless of who wrote it. Thus, it is not the intent of this section to claim rights or contest your rights to work written entirely by you; rather, the intent is to exercise the right to control the distribution of derivative or collective works based on the Program. In addition, mere aggregation of another work not based on the Program with the Program (or with a work based on the Program) on a volume of a storage or distribution medium does not bring the other work under the scope of this License. 3. 
You may copy and distribute the Program (or a work based on it, under Section 2) in object code or executable form under the terms of Sections 1 and 2 above provided that you also do one of the following: a) Accompany it with the complete corresponding machine-readable source code, which must be distributed under the terms of Sections 1 and 2 above on a medium customarily used for software interchange; or, b) Accompany it with a written offer, valid for at least three years, to give any third party, for a charge no more than your cost of physically performing source distribution, a complete machine-readable copy of the corresponding source code, to be distributed under the terms of Sections 1 and 2 above on a medium customarily used for software interchange; or, c) Accompany it with the information you received as to the offer to distribute corresponding source code. (This alternative is allowed only for noncommercial distribution and only if you received the program in object code or executable form with such an offer, in accord with Subsection b above.) The source code for a work means the preferred form of the work for making modifications to it. For an executable work, complete source code means all the source code for all modules it contains, plus any associated interface definition files, plus the scripts used to control compilation and installation of the executable. However, as a special exception, the source code distributed need not include anything that is normally distributed (in either source or binary form) with the major components (compiler, kernel, and so on) of the operating system on which the executable runs, unless that component itself accompanies the executable. 
If distribution of executable or object code is made by offering access to copy from a designated place, then offering equivalent access to copy the source code from the same place counts as distribution of the source code, even though third parties are not compelled to copy the source along with the object code. 4. You may not copy, modify, sublicense, or distribute the Program except as expressly provided under this License. Any attempt otherwise to copy, modify, sublicense or distribute the Program is void, and will automatically terminate your rights under this License. However, parties who have received copies, or rights, from you under this License will not have their licenses terminated so long as such parties remain in full compliance. 5. You are not required to accept this License, since you have not signed it. However, nothing else grants you permission to modify or distribute the Program or its derivative works. These actions are prohibited by law if you do not accept this License. Therefore, by modifying or distributing the Program (or any work based on the Program), you indicate your acceptance of this License to do so, and all its terms and conditions for copying, distributing or modifying the Program or works based on it. 6. Each time you redistribute the Program (or any work based on the Program), the recipient automatically receives a license from the original licensor to copy, distribute or modify the Program subject to these terms and conditions. You may not impose any further restrictions on the recipients' exercise of the rights granted herein. You are not responsible for enforcing compliance by third parties to this License. 7. If, as a consequence of a court judgment or allegation of patent infringement or for any other reason (not limited to patent issues), conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. 
If you cannot distribute so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not distribute the Program at all. For example, if a patent license would not permit royalty-free redistribution of the Program by all those who receive copies directly or indirectly through you, then the only way you could satisfy both it and this License would be to refrain entirely from distribution of the Program. If any portion of this section is held invalid or unenforceable under any particular circumstance, the balance of the section is intended to apply and the section as a whole is intended to apply in other circumstances. It is not the purpose of this section to induce you to infringe any patents or other property right claims or to contest validity of any such claims; this section has the sole purpose of protecting the integrity of the free software distribution system, which is implemented by public license practices. Many people have made generous contributions to the wide range of software distributed through that system in reliance on consistent application of that system; it is up to the author/donor to decide if he or she is willing to distribute software through any other system and a licensee cannot impose that choice. This section is intended to make thoroughly clear what is believed to be a consequence of the rest of this License. 8. If the distribution and/or use of the Program is restricted in certain countries either by patents or by copyrighted interfaces, the original copyright holder who places the Program under this License may add an explicit geographical distribution limitation excluding those countries, so that distribution is permitted only in or among countries not thus excluded. In such case, this License incorporates the limitation as if written in the body of this License. 9. 
The Free Software Foundation may publish revised and/or new versions of the General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the Program specifies a version number of this License which applies to it and "any later version", you have the option of following the terms and conditions either of that version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of this License, you may choose any version ever published by the Free Software Foundation. 10. If you wish to incorporate parts of the Program into other free programs whose distribution conditions are different, write to the author to ask for permission. For software which is copyrighted by the Free Software Foundation, write to the Free Software Foundation; we sometimes make exceptions for this. Our decision will be guided by the two goals of preserving the free status of all derivatives of our free software and of promoting the sharing and reuse of software generally. NO WARRANTY 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 12. 
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. END OF TERMS AND CONDITIONS How to Apply These Terms to Your New Programs If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively convey the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. Copyright (C) This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. Also add information on how to contact you by electronic and paper mail. 
If the program is interactive, make it output a short notice like this when it starts in an interactive mode: Gnomovision version 69, Copyright (C) year name of author Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. This is free software, and you are welcome to redistribute it under certain conditions; type `show c' for details. The hypothetical commands `show w' and `show c' should show the appropriate parts of the General Public License. Of course, the commands you use may be called something other than `show w' and `show c'; they could even be mouse-clicks or menu items--whatever suits your program. You should also get your employer (if you work as a programmer) or your school, if any, to sign a "copyright disclaimer" for the program, if necessary. Here is a sample; alter the names: Yoyodyne, Inc., hereby disclaims all copyright interest in the program `Gnomovision' (which makes passes at compilers) written by James Hacker. , 1 April 1989 Ty Coon, President of Vice This General Public License does not permit incorporating your program into proprietary programs. If your program is a subroutine library, you may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Lesser General Public License instead of this License. ================================================ FILE: README.md ================================================ # Shanghai - AI Powered Art in a Discord Bot! ### Any questions or need help? Come hop on by to our Discord server! [![Discord Server](https://discordapp.com/api/guilds/930499730843250783/widget.png?style=banner2)](https://discord.gg/Sx6Spmsgx7) ## Setup Make sure you have the [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads) installed Clone the repository and enter it ```` git clone https://github.com/harubaru/discord-stable-diffusion.git cd discord-stable-diffusion ```` #### WINDOWS SETUP Run `setup.bat`. 
If you run into any errors, try running the file as administrator. If you are on a Windows 10 system, run `win10fix.bat`. Modify the `run.bat` file, where * `--model_path` is the path to the model (make sure to replace any backslashes with double backslashes), * `--token=` is the token for the Discord bot * `--hf_token=` is your Hugging Face token (can be found [here](https://huggingface.co/settings/tokens)) Run the `run.bat` file. #### LINUX SETUP Run `./setup.sh`. If you run into any errors, try using `sudo ./setup.sh`. Modify the `run.sh` file, where * `--model_path` is the path to the model, * `--token=` is the token for the Discord bot * `--hf_token=` is your Hugging Face token (can be found [here](https://huggingface.co/settings/tokens)) Run `./run.sh`. ### Quickstart #### Text to Image To generate an image from text, use the ``/dream`` command and include your prompt as the query. There are tons of parameters to play with, so go wild! ![image](https://user-images.githubusercontent.com/26317155/186722689-3cbca12a-531c-47f7-b87f-99918e9ed232.png) ![image](https://user-images.githubusercontent.com/26317155/186721768-3684f629-90c3-4ef2-82b8-1310200df437.png) #### Image to Image To generate an image from another image, use the ``/dream`` command and include the `init_image` and `strength` parameters. The image is supplied as an attachment in the `init_image` option. ![image](https://user-images.githubusercontent.com/26317155/186722463-ec3a6d24-36c1-48f8-b09a-57651706848c.png) ![image](https://user-images.githubusercontent.com/26317155/186722528-7e652a21-fd02-4071-9fc1-87a31dfb6e63.png) #### (Experimental) Inpainting To fill in a masked region of an image, supply a prompt along with the `init_image`, `mask_image` and `strength` parameters. The mask needs to consist of black pixels in a transparent image.
![image](https://user-images.githubusercontent.com/26317155/186722970-71a662dc-16a8-4bb4-8696-3bafb3e08e65.png) ================================================ FILE: __main__.py ================================================ import os import sys import argparse import asyncio from src.core.logging import get_logger from src.bot.shanghai import Shanghai logger = get_logger(__name__) def parse_args(): parser = argparse.ArgumentParser( description='Shanghai - A Discord bot for AI powered utilities.', usage='shanghai [arguments]' ) parser.add_argument('--prefix', type=str, help='The prefix to use for commands.', default='s!') parser.add_argument('--token', type=str, help='The token to use for authentication.') parser.add_argument('--hf_token', type=str, help='The token to use for HuggingFace authentication.', default=None) parser.add_argument('--model_path', type=str, help='Path to the model.', default=None) return parser.parse_args() async def shutdown(bot): await bot.close() def main(): shanghai = None args = parse_args() try: shanghai = Shanghai(args) logger.info('Executing bot.') shanghai.run(args.token) except KeyboardInterrupt: logger.info('Keyboard interrupt received. Exiting.') asyncio.run(shutdown(shanghai)) except SystemExit: logger.info('System exit received. 
Exiting.') asyncio.run(shutdown(shanghai)) except Exception as e: logger.error(e) asyncio.run(shutdown(shanghai)) finally: sys.exit(0) if __name__ == '__main__': main() ================================================ FILE: models/.keep ================================================ 壊れたカーテンの隙間から 壁を埋めるのは 暴言?妄言?知りません。 ================================================ FILE: models/v1-inference.yaml ================================================ model: base_learning_rate: 1.0e-04 target: src.stablediffusion.ldm.models.diffusion.ddpm.LatentDiffusion params: linear_start: 0.00085 linear_end: 0.0120 num_timesteps_cond: 1 log_every_t: 200 timesteps: 1000 first_stage_key: "jpg" cond_stage_key: "txt" image_size: 64 channels: 4 cond_stage_trainable: false # Note: different from the one we trained before conditioning_key: crossattn monitor: val/loss_simple_ema scale_factor: 0.18215 use_ema: False scheduler_config: # 10000 warmup steps target: src.stablediffusion.ldm.lr_scheduler.LambdaLinearScheduler params: warm_up_steps: [ 10000 ] cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases f_start: [ 1.e-6 ] f_max: [ 1. ] f_min: [ 1. 
] personalization_config: target: src.stablediffusion.ldm.modules.embedding_manager.EmbeddingManager params: placeholder_strings: ["*"] initializer_words: ["sculpture"] per_image_tokens: false num_vectors_per_token: 1 progressive_words: False unet_config: target: src.stablediffusion.ldm.modules.diffusionmodules.openaimodel.UNetModel params: image_size: 32 # unused in_channels: 4 out_channels: 4 model_channels: 320 attention_resolutions: [ 4, 2, 1 ] num_res_blocks: 2 channel_mult: [ 1, 2, 4, 4 ] num_heads: 8 use_spatial_transformer: True transformer_depth: 1 context_dim: 768 use_checkpoint: True legacy: False first_stage_config: target: src.stablediffusion.ldm.models.autoencoder.AutoencoderKL params: embed_dim: 4 monitor: val/rec_loss ddconfig: double_z: true z_channels: 4 resolution: 256 in_channels: 3 out_ch: 3 ch: 128 ch_mult: - 1 - 2 - 4 - 4 num_res_blocks: 2 attn_resolutions: [] dropout: 0.0 lossconfig: target: torch.nn.Identity cond_stage_config: target: src.stablediffusion.ldm.modules.encoders.modules.FrozenCLIPEmbedder ================================================ FILE: requirements.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu117 torch diffusers numpy Pillow pydantic git+https://github.com/Pycord-Development/pycord omegaconf==2.1.1 pytorch-lightning==1.4.2 taming-transformers-rom1504==0.0.6 test-tube>=0.7.5 torch-fidelity==0.3.0 torchmetrics==0.6.0 transformers==4.19.2 git+https://github.com/openai/CLIP.git@main#egg=clip git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion ================================================ FILE: run.bat ================================================ venv\Scripts\python.exe . --model_path "" --token="" ================================================ FILE: run.sh ================================================ venv/bin/python . 
--model_path "" --token="" --hf_token="" ================================================ FILE: setup.bat ================================================ python -m venv venv venv\Scripts\pip.exe install -r requirements.txt ================================================ FILE: setup.sh ================================================ python -m venv venv venv/bin/pip install -r requirements.txt ================================================ FILE: src/bot/shanghai.py ================================================ import asyncio import os from abc import ABC import discord from discord.ext import commands from src.core.logging import get_logger class Shanghai(commands.Bot, ABC): def __init__(self, args): intents = discord.Intents.default() intents.members = True super().__init__(command_prefix=args.prefix, intents=intents) self.args = args self.logger = get_logger(__name__) self.load_extension('src.bot.stablecog') async def on_ready(self): self.logger.info(f'Logged in as {self.user.name} ({self.user.id})') await self.change_presence( activity=discord.Activity(type=discord.ActivityType.watching, name='you over the seven seas.')) async def on_message(self, message): if message.author == self.user: try: # Check if the message from Shanghai was actually a generation if message.embeds[0].fields[0].name == 'command': await message.add_reaction('❌') except: pass async def on_raw_reaction_add(self, ctx): if ctx.emoji.name == '❌': message = await self.get_channel(ctx.channel_id).fetch_message(ctx.message_id) if message.embeds: # look at the message footer to see if the generation was by the user who reacted if message.embeds[0].footer.text == f'{ctx.member.name}#{ctx.member.discriminator}': await message.delete() ================================================ FILE: src/bot/stablecog.py ================================================ import traceback from asyncio import AbstractEventLoop from threading import Thread import requests import asyncio import discord from 
discord.ext import commands from typing import Optional from io import BytesIO from PIL import Image from discord import option import random import time from src.stablediffusion.text2image_compvis import Text2Image embed_color = discord.Colour.from_rgb(215, 195, 134) class QueueObject: def __init__(self, ctx, prompt, height, width, guidance_scale, steps, seed, strength, init_image, mask_image, sampler_name, command_str): self.ctx = ctx self.prompt = prompt self.height = height self.width = width self.guidance_scale = guidance_scale self.steps = steps self.seed = seed self.strength = strength self.init_image = init_image self.mask_image = mask_image self.sampler_name = sampler_name self.command_str = command_str class StableCog(commands.Cog, name='Stable Diffusion', description='Create images from natural language.'): def __init__(self, bot): self.dream_thread = Thread() self.text2image_model = Text2Image(model_path=bot.args.model_path) self.event_loop = asyncio.get_event_loop() self.queue = [] self.bot = bot @commands.slash_command(name='dream', description='Create an image.') @option( 'prompt', str, description='A prompt to condition the model with.', required=True, ) @option( 'height', int, description='Height of the generated image.', required=False, choices=[x for x in range(192, 832, 64)] ) @option( 'width', int, description='Width of the generated image.', required=False, choices=[x for x in range(192, 832, 64)] ) @option( 'guidance_scale', float, description='Classifier-Free Guidance scale', required=False, ) @option( 'steps', int, description='The amount of steps to sample the model', required=False, choices=[x for x in range(5, 55, 5)] ) @option( 'sampler', str, description='The sampler to use for generation', required=False, choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'], default='ddim' ) @option( 'seed', int, description='The seed to use for reproduceability', required=False, ) @option( 'strength', float, 
description='The strength (0.0 to 1.0) used to apply the prompt to the init_image/mask_image' ) @option( 'init_image', discord.Attachment, description='The image to initialize the latents with for denoising', required=False, ) @option( 'mask_image', discord.Attachment, description='The mask image to use for inpainting', required=False, ) async def dream_handler(self, ctx: discord.ApplicationContext, *, prompt: str, height: Optional[int] = 512, width: Optional[int] = 512, guidance_scale: Optional[float] = 7.0, steps: Optional[int] = 30, sampler: Optional[str] = 'k_euler_a', seed: Optional[int] = -1, strength: Optional[float] = None, init_image: Optional[discord.Attachment] = None, mask_image: Optional[discord.Attachment] = None): print(f'Request -- {ctx.author.name}#{ctx.author.discriminator} -- Prompt: {prompt}') if seed == -1: seed = random.randint(0, 0xFFFFFFFF) command_str = '/dream' command_str = command_str + f' prompt:{prompt} height:{str(height)} width:{width} guidance_scale:{guidance_scale} steps:{steps} sampler:{sampler} seed:{seed}' if init_image or mask_image: command_str = command_str + f' strength:{strength}' if self.dream_thread.is_alive(): user_already_in_queue = False for queue_object in self.queue: if queue_object.ctx.author.id == ctx.author.id: user_already_in_queue = True break if user_already_in_queue: await ctx.send_response( content=f'Please wait for your current image to finish generating before generating a new image', ephemeral=True) else: self.queue.append(QueueObject(ctx, prompt, height, width, guidance_scale, steps, seed, strength, init_image, mask_image, sampler, command_str)) await ctx.send_response( content=f'Dreaming for <@{ctx.author.id}> - Queue Position: ``{len(self.queue)}`` - ``{command_str}``') else: await self.process_dream(QueueObject(ctx, prompt, height, width, guidance_scale, steps, seed, strength, init_image, mask_image, sampler, command_str)) await ctx.send_response( content=f'Dreaming for <@{ctx.author.id}> - Queue 
Position: ``{len(self.queue)}`` - ``{command_str}``') async def process_dream(self, queue_object: QueueObject): self.dream_thread = Thread(target=self.dream, args=(self.event_loop, queue_object)) self.dream_thread.start() def dream(self, event_loop: AbstractEventLoop, queue_object: QueueObject): try: start_time = time.time() if (queue_object.init_image is None) and (queue_object.mask_image is None): samples, seed = self.text2image_model.dream(queue_object.prompt, queue_object.steps, False, False, 0.0, 1, 1, queue_object.guidance_scale, queue_object.seed, queue_object.height, queue_object.width, False, queue_object.sampler_name) elif queue_object.init_image is not None: image = Image.open(requests.get(queue_object.init_image.url, stream=True).raw).convert('RGB') samples, seed = self.text2image_model.translation(queue_object.prompt, image, queue_object.steps, 0.0, 0, 0, queue_object.guidance_scale, queue_object.strength, queue_object.seed, queue_object.height, queue_object.width, queue_object.sampler_name) else: image = Image.open(requests.get(queue_object.init_image.url, stream=True).raw).convert('RGB') mask = Image.open(requests.get(queue_object.mask_image.url, stream=True).raw).convert('RGB') samples, seed = self.text2image_model.inpaint(queue_object.prompt, image, mask, queue_object.steps, 0.0, 1, 1, queue_object.guidance_scale, denoising_strength=queue_object.strength, seed=queue_object.seed, height=queue_object.height, width=queue_object.width, sampler_name=queue_object.sampler_name) end_time = time.time() with BytesIO() as buffer: samples[0].save(buffer, 'PNG') buffer.seek(0) embed = discord.Embed() embed.colour = embed_color embed.add_field(name='command', value=f'``{queue_object.command_str}``', inline=False) embed.add_field(name='compute used', value='``{0:.3f}`` seconds'.format(end_time - start_time), inline=False) embed.add_field(name='delete', value='React with ❌ to delete your own generation') # fix errors if user doesn't have pfp if 
queue_object.ctx.author.avatar is None: embed.set_footer( text=f'{queue_object.ctx.author.name}#{queue_object.ctx.author.discriminator}') else: embed.set_footer( text=f'{queue_object.ctx.author.name}#{queue_object.ctx.author.discriminator}', icon_url=queue_object.ctx.author.avatar.url) event_loop.create_task( queue_object.ctx.channel.send(content=f'<@{queue_object.ctx.author.id}>', embed=embed, file=discord.File(fp=buffer, filename=f'{seed}.png'))) except Exception as e: embed = discord.Embed(title='txt2img failed', description=f'{e}\n{traceback.print_exc()}', color=embed_color) event_loop.create_task(queue_object.ctx.channel.send(embed=embed)) if self.queue: event_loop.create_task(self.process_dream(self.queue.pop(0))) def setup(bot): bot.add_cog(StableCog(bot)) ================================================ FILE: src/core/logging.py ================================================ import logging logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s: %(message)s', datefmt='%Y-%m-%d %H:%M:%S') def get_logger(name): return logging.getLogger(name) ================================================ FILE: src/scripts/win10patch.py ================================================ try: file_path = 'venv\\lib\\site-packages\\torch\\distributed\\elastic\\timer\\file_based_local_timer.py' with open(file_path, 'r+') as file: old = file.read() if 'SIGKILL' not in old: print(file_path + ' already patched!') exit(0) file.seek(0) file.write(old.replace('SIGKILL', 'SIGINT')) print('Patched ' + file_path) except Exception as e: print('Patch failed! 
Please report this either on github or to salt#7234\nReason: ' + str(e)) ================================================ FILE: src/stablediffusion/dream.py ================================================ import inspect import warnings from typing import List, Optional, Union import torch from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, UNet2DConditionModel, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from PIL import Image class StableDiffusionPipeline(DiffusionPipeline): def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] ): super().__init__() scheduler = scheduler.set_format("pt") self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, ) @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], height: Optional[int] = 512, width: Optional[int] = 512, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, eta: Optional[float] = 0.0, generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", progress: Optional[bool] = False, **kwargs, ): if "torch_device" in kwargs: device = kwargs.pop("torch_device") warnings.warn( "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0." " Consider using `pipe.to(torch_device)` instead." 
)

            # Set device as before (to be removed in 0.3.0)
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            self.to(device)

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # The latents below are (height // 8, width // 8), so both sides must divide evenly by 8.
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the intial random noise
        latents = torch.randn(
            (batch_size, self.unet.in_channels, height // 8, width // 8),
            generator=generator,
            device=self.device,
        )

        # set timesteps
        # `offset` is only forwarded to schedulers whose set_timesteps() signature accepts it.
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # Decoded preview frames (first sample of the batch only) for the optional progress GIF.
        images = []

        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                # LMS divides the model input by sqrt(sigma^2 + 1) for the current step index.
                sigma = self.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                # The batch was built as [unconditional, conditional], so chunk(2) splits it in that order.
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            # LMSDiscreteScheduler.step receives the step index i; the other schedulers receive timestep t.
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

            if progress:
                # Decode the current latents into a preview frame: one extra VAE decode per step.
                latent_image = self.vae.decode(1 / 0.18215 * latents)
                latent_image = (latent_image / 2 + 0.5).clamp(0, 1)
                latent_image = latent_image.cpu().permute(0, 2, 3, 1).numpy()
                if latent_image.ndim == 3:
                    latent_image = latent_image[None, ...]
                latent_image = (latent_image * 255).round().astype('uint8')
                latent_image = [Image.fromarray(image) for image in latent_image]
                images.append(latent_image[0])

        if progress:
            # Writes a looping GIF (125 ms per frame) to output.gif in the working directory.
            # NOTE(review): images[0] raises IndexError if the scheduler produced no timesteps.
            images[0].save(f'output.gif', save_all=True, append_images=images[1:], optimize=False, loop=0, duration=125)

        # scale and decode the image latents with vae
        # (1 / 0.18215 undoes the latent scale factor; inpaint.py multiplies by 0.18215 when encoding)
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)

        # Map to [0, 1] (assumes decoder output in [-1, 1]) and move channels last,
        # presumably NCHW -> NHWC for numpy_to_pil.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return {"sample": image}

================================================
FILE: src/stablediffusion/inpaint.py
================================================
import inspect
from typing import List, Optional, Union

import numpy as np
import torch

import PIL
from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


def preprocess(image):
    """Resize a PIL image down to multiples of 32 and return a (1, C, H, W) float tensor in [-1, 1].

    Assumes a 3-channel (RGB) image -- TODO confirm callers always convert first.
    """
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0


def preprocess_mask(mask):
    """Turn a PIL mask into a binary (1, 4, H/8, W/8) long tensor: any nonzero pixel becomes 1.

    The 4 tiled channels presumably match the latent channel count -- TODO confirm.
    """
    mask=mask.convert("L")
    w, h = mask.size
    mask = mask.resize((int(w / 8), int(h / 8)), resample=PIL.Image.LANCZOS)
    mask = np.array(mask).astype(np.float32) / 255.0
    mask = np.tile(mask,(4,1,1))
    mask = mask[None].transpose(0, 1, 2, 3)  # [None] adds the batch axis; transpose(0, 1, 2, 3) is the identity
    mask = torch.from_numpy(mask).bool()
    return (mask).long()


class StableDiffusionInpaintingPipeline(DiffusionPipeline):
    """Inpainting pipeline: re-noises `init_image` and repaints it where the mask is 0, keeping mask == 1 areas."""

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler],
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        init_image: torch.FloatTensor,
        mask_image: torch.FloatTensor,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
    ):
        """Inpaint `init_image` under `mask_image`; `strength` in [0, 1] sets how many steps are run.

        `mask_image` is handed to preprocess_mask(), which calls .convert("L"), so despite the
        annotation it is presumably a PIL image -- TODO confirm.
        """

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # encode the init image into latents and scale the latents
        init_latents = self.vae.encode(init_image.to(self.device)).sample()
        init_latents = 0.18215 * init_latents
        # Unnoised copy, re-noised to each timestep during the masked blend.
        init_latents_orig = init_latents

        # prepare init_latents noise to latents
        init_latents = torch.cat([init_latents] * batch_size)

        # preprocess mask
        mask = preprocess_mask(mask_image).to(self.device)
        mask = torch.cat([mask] * batch_size)

        # get the original timestep using init_timestep
        # init_timestep is the number of denoising steps to run, proportional to strength.
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        # NOTE(review): if init_timestep is 0 (strength 0 with no offset), [-0] selects timesteps[0],
        # presumably the noisiest step rather than "no noise" -- confirm intended behavior.
        timesteps = self.scheduler.timesteps[-init_timestep]
        timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)

        # add noise to latents using the timesteps
        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] #masking init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) latents = ( init_latents_proper * mask ) + ( latents * (1-mask) ) # scale and decode the image latents with vae latents = 1 / 0.18215 * latents image = self.vae.decode(latents) image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": image = self.numpy_to_pil(image) return {"sample": image, "nsfw_content_detected": False} ================================================ FILE: src/stablediffusion/ldm/__init__.py ================================================ from .generate import Generate ================================================ FILE: src/stablediffusion/ldm/data/__init__.py ================================================ ================================================ FILE: src/stablediffusion/ldm/data/base.py ================================================ from abc import abstractmethod from torch.utils.data import 
( Dataset, ConcatDataset, ChainDataset, IterableDataset, ) class Txt2ImgIterableBaseDataset(IterableDataset): """ Define an interface to make the IterableDatasets for text2img data chainable """ def __init__(self, num_records=0, valid_ids=None, size=256): super().__init__() self.num_records = num_records self.valid_ids = valid_ids self.sample_ids = valid_ids self.size = size print( f'{self.__class__.__name__} dataset contains {self.__len__()} examples.' ) def __len__(self): return self.num_records @abstractmethod def __iter__(self): pass ================================================ FILE: src/stablediffusion/ldm/data/imagenet.py ================================================ import os, yaml, pickle, shutil, tarfile, glob import cv2 import albumentations import PIL import numpy as np import torchvision.transforms.functional as TF from omegaconf import OmegaConf from functools import partial from PIL import Image from tqdm import tqdm from torch.utils.data import Dataset, Subset import taming.data.utils as tdu from taming.data.imagenet import ( str_to_indices, give_synsets_from_indices, download, retrieve, ) from taming.data.imagenet import ImagePaths from ldm.modules.image_degradation import ( degradation_fn_bsr, degradation_fn_bsr_light, ) def synset2idx(path_to_yaml='data/index_synset.yaml'): with open(path_to_yaml) as f: di2s = yaml.load(f) return dict((v, k) for k, v in di2s.items()) class ImageNetBase(Dataset): def __init__(self, config=None): self.config = config or OmegaConf.create() if not type(self.config) == dict: self.config = OmegaConf.to_container(self.config) self.keep_orig_class_label = self.config.get( 'keep_orig_class_label', False ) self.process_images = True # if False we skip loading & processing images and self.data contains filepaths self._prepare() self._prepare_synset_to_human() self._prepare_idx_to_synset() self._prepare_human_to_integer_label() self._load() def __len__(self): return len(self.data) def __getitem__(self, i): return 
self.data[i] def _prepare(self): raise NotImplementedError() def _filter_relpaths(self, relpaths): ignore = set( [ 'n06596364_9591.JPEG', ] ) relpaths = [ rpath for rpath in relpaths if not rpath.split('/')[-1] in ignore ] if 'sub_indices' in self.config: indices = str_to_indices(self.config['sub_indices']) synsets = give_synsets_from_indices( indices, path_to_yaml=self.idx2syn ) # returns a list of strings self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) files = [] for rpath in relpaths: syn = rpath.split('/')[0] if syn in synsets: files.append(rpath) return files else: return relpaths def _prepare_synset_to_human(self): SIZE = 2655750 URL = 'https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1' self.human_dict = os.path.join(self.root, 'synset_human.txt') if ( not os.path.exists(self.human_dict) or not os.path.getsize(self.human_dict) == SIZE ): download(URL, self.human_dict) def _prepare_idx_to_synset(self): URL = 'https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1' self.idx2syn = os.path.join(self.root, 'index_synset.yaml') if not os.path.exists(self.idx2syn): download(URL, self.idx2syn) def _prepare_human_to_integer_label(self): URL = 'https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1' self.human2integer = os.path.join( self.root, 'imagenet1000_clsidx_to_labels.txt' ) if not os.path.exists(self.human2integer): download(URL, self.human2integer) with open(self.human2integer, 'r') as f: lines = f.read().splitlines() assert len(lines) == 1000 self.human2integer_dict = dict() for line in lines: value, key = line.split(':') self.human2integer_dict[key] = int(value) def _load(self): with open(self.txt_filelist, 'r') as f: self.relpaths = f.read().splitlines() l1 = len(self.relpaths) self.relpaths = self._filter_relpaths(self.relpaths) print( 'Removed {} files from filelist during filtering.'.format( l1 - len(self.relpaths) ) ) self.synsets = [p.split('/')[0] for p in self.relpaths] self.abspaths = [os.path.join(self.datadir, p) 
for p in self.relpaths]

        unique_synsets = np.unique(self.synsets)
        # Contiguous labels 0..K-1 over the synsets present (np.unique returns them sorted).
        class_dict = dict(
            (synset, i) for i, synset in enumerate(unique_synsets)
        )
        if not self.keep_orig_class_label:
            self.class_labels = [class_dict[s] for s in self.synsets]
        else:
            # NOTE(review): self.synset2idx is only assigned in _filter_relpaths when 'sub_indices'
            # is configured; keep_orig_class_label=True without it raises AttributeError.
            self.class_labels = [self.synset2idx[s] for s in self.synsets]

        with open(self.human_dict, 'r') as f:
            human_dict = f.read().splitlines()
            # Each line is "<synset> <human-readable label>"; split only on the first whitespace.
            human_dict = dict(line.split(maxsplit=1) for line in human_dict)

        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            'relpath': np.array(self.relpaths),
            'synsets': np.array(self.synsets),
            'class_label': np.array(self.class_labels),
            'human_label': np.array(self.human_labels),
        }

        if self.process_images:
            self.size = retrieve(self.config, 'size', default=256)
            self.data = ImagePaths(
                self.abspaths,
                labels=labels,
                size=self.size,
                random_crop=self.random_crop,
            )
        else:
            self.data = self.abspaths


class ImageNetTrain(ImageNetBase):
    """ILSVRC2012 training split.

    On first use, fetches the tarball via academictorrents if missing, extracts it and its
    per-class sub-tars under <root>/data, writes filelist.txt and marks the root prepared.
    """

    NAME = 'ILSVRC2012_train'
    URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
    AT_HASH = 'a306397ccf9c2ead27155983c254227c0fd938e2'
    FILES = [
        'ILSVRC2012_img_train.tar',
    ]
    SIZES = [
        147897477120,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.process_images = process_images
        self.data_root = data_root
        super().__init__(**kwargs)

    def _prepare(self):
        # Root is <data_root>/<NAME> if given, else $XDG_CACHE_HOME (or ~/.cache)/autoencoders/data/<NAME>.
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get(
                'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
            )
            self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)

        self.datadir = os.path.join(self.root, 'data')
        self.txt_filelist = os.path.join(self.root, 'filelist.txt')
        self.expected_length = 1281167
        self.random_crop = retrieve(
            self.config, 'ImageNetTrain/random_crop', default=True
        )
        if not tdu.is_prepared(self.root):
            # prep
            print('Preparing dataset {} in {}'.format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                # Re-download when the tarball is missing or its size does not match.
                if (
                    not os.path.exists(path)
                    or not os.path.getsize(path) == self.SIZES[0]
                ):
                    import academictorrents as at

                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print('Extracting {} to {}'.format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                # NOTE(review): extractall() without a member filter trusts archive paths; fine only
                # for trusted archives (consider filter='data' on Python 3.12+).
                with tarfile.open(path, 'r:') as tar:
                    tar.extractall(path=datadir)

                print('Extracting sub-tars.')
                subpaths = sorted(glob.glob(os.path.join(datadir, '*.tar')))
                for subpath in tqdm(subpaths):
                    subdir = subpath[: -len('.tar')]
                    os.makedirs(subdir, exist_ok=True)
                    with tarfile.open(subpath, 'r:') as tar:
                        tar.extractall(path=subdir)

            # filelist.txt: sorted JPEG paths relative to datadir, one per line.
            filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = '\n'.join(filelist) + '\n'
            with open(self.txt_filelist, 'w') as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetValidation(ImageNetBase):
    """ILSVRC2012 validation split.

    Like ImageNetTrain, but after extraction it downloads validation_synset.txt and moves each
    image into a folder named after its synset.
    """

    NAME = 'ILSVRC2012_validation'
    URL = 'http://www.image-net.org/challenges/LSVRC/2012/'
    AT_HASH = '5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5'
    VS_URL = 'https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1'
    FILES = [
        'ILSVRC2012_img_val.tar',
        'validation_synset.txt',
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    def __init__(self, process_images=True, data_root=None, **kwargs):
        self.data_root = data_root
        self.process_images = process_images
        super().__init__(**kwargs)

    def _prepare(self):
        if self.data_root:
            self.root = os.path.join(self.data_root, self.NAME)
        else:
            cachedir = os.environ.get(
                'XDG_CACHE_HOME', os.path.expanduser('~/.cache')
            )
            self.root = os.path.join(cachedir, 'autoencoders/data', self.NAME)
        self.datadir = os.path.join(self.root, 'data')
        self.txt_filelist = os.path.join(self.root, 'filelist.txt')
        self.expected_length = 50000
        self.random_crop = retrieve(
            self.config, 'ImageNetValidation/random_crop', default=False
        )
        if not tdu.is_prepared(self.root):
            # prep
            print('Preparing dataset {} in {}'.format(self.NAME, self.root))

            datadir = self.datadir
            if not os.path.exists(datadir):
                path = os.path.join(self.root, self.FILES[0])
                if (
                    not os.path.exists(path)
                    or not os.path.getsize(path) == self.SIZES[0]
                ):
                    import academictorrents as at

                    atpath = at.get(self.AT_HASH, datastore=self.root)
                    assert atpath == path

                print('Extracting {} to {}'.format(path, datadir))
                os.makedirs(datadir, exist_ok=True)
                # NOTE(review): same unfiltered extractall() caveat as ImageNetTrain.
                with tarfile.open(path, 'r:') as tar:
                    tar.extractall(path=datadir)

                vspath = os.path.join(self.root, self.FILES[1])
                if (
                    not os.path.exists(vspath)
                    or not os.path.getsize(vspath) == self.SIZES[1]
                ):
                    download(self.VS_URL, vspath)

                # Each line is "<image file> <synset>".
                with open(vspath, 'r') as f:
                    synset_dict = f.read().splitlines()
                    synset_dict = dict(line.split() for line in synset_dict)

                print('Reorganizing into synset folders')
                synsets = np.unique(list(synset_dict.values()))
                for s in synsets:
                    os.makedirs(os.path.join(datadir, s), exist_ok=True)
                for k, v in synset_dict.items():
                    src = os.path.join(datadir, k)
                    dst = os.path.join(datadir, v)
                    shutil.move(src, dst)

            filelist = glob.glob(os.path.join(datadir, '**', '*.JPEG'))
            filelist = [os.path.relpath(p, start=datadir) for p in filelist]
            filelist = sorted(filelist)
            filelist = '\n'.join(filelist) + '\n'
            with open(self.txt_filelist, 'w') as f:
                f.write(filelist)

            tdu.mark_prepared(self.root)


class ImageNetSR(Dataset):
    def __init__(
        self,
        size=None,
        degradation=None,
        downscale_f=4,
        min_crop_f=0.5,
        max_crop_f=1.0,
        random_crop=True,
    ):
        """
        Imagenet Superresolution Dataloader
        Performs following ops in order:
        1. crops a crop of size s from image either as random or center crop
        2. resizes crop to size with cv2.area_interpolation
        3. degrades resized crop with degradation_fn

        :param size: resizing to size after cropping
        :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
        :param downscale_f: Low Resolution Downsample factor
        :param min_crop_f: determines crop size s,
          where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
        :param max_crop_f: ""
        :param data_root:
        :param random_crop:
        """
        # get_base() is supplied by subclasses (ImageNetSRTrain / ImageNetSRValidation).
        self.base = self.get_base()
        assert size
        assert (size / downscale_f).is_integer()
        self.size = size
        self.LR_size = int(size / downscale_f)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert max_crop_f <= 1.0
        self.center_crop = not random_crop

        self.image_rescaler = albumentations.SmallestMaxSize(
            max_size=size, interpolation=cv2.INTER_AREA
        )

        self.pil_interpolation = (
            False  # gets reset later if incase interp_op is from pillow
        )

        if degradation == 'bsrgan':
            self.degradation_process = partial(
                degradation_fn_bsr, sf=downscale_f
            )

        elif degradation == 'bsrgan_light':
            self.degradation_process = partial(
                degradation_fn_bsr_light, sf=downscale_f
            )

        else:
            # Any other name must be a key below; an unknown name raises KeyError.
            interpolation_fn = {
                'cv_nearest': cv2.INTER_NEAREST,
                'cv_bilinear': cv2.INTER_LINEAR,
                'cv_bicubic': cv2.INTER_CUBIC,
                'cv_area': cv2.INTER_AREA,
                'cv_lanczos': cv2.INTER_LANCZOS4,
                'pil_nearest': PIL.Image.NEAREST,
                'pil_bilinear': PIL.Image.BILINEAR,
                'pil_bicubic': PIL.Image.BICUBIC,
                'pil_box': PIL.Image.BOX,
                'pil_hamming': PIL.Image.HAMMING,
                'pil_lanczos': PIL.Image.LANCZOS,
            }[degradation]

            self.pil_interpolation = degradation.startswith('pil_')

            if self.pil_interpolation:
                self.degradation_process = partial(
                    TF.resize,
                    size=self.LR_size,
                    interpolation=interpolation_fn,
                )

            else:
                self.degradation_process = albumentations.SmallestMaxSize(
                    max_size=self.LR_size, interpolation=interpolation_fn
                )

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        # Returns the base example plus 'image' (HR) and 'LR_image', both float32 in [-1, 1].
        example = self.base[i]
        image = Image.open(example['file_path_'])

        if not image.mode == 'RGB':
            image = image.convert('RGB')

        image = np.array(image).astype(np.uint8)

        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(
            self.min_crop_f, self.max_crop_f, size=None
        )
        crop_side_len = int(crop_side_len)

        if self.center_crop:
            self.cropper = albumentations.CenterCrop(
                height=crop_side_len, width=crop_side_len
            )

        else:
            self.cropper = albumentations.RandomCrop(
                height=crop_side_len, width=crop_side_len
            )

        image = self.cropper(image=image)['image']
        image = self.image_rescaler(image=image)['image']

        if self.pil_interpolation:
            image_pil = PIL.Image.fromarray(image)
            LR_image = self.degradation_process(image_pil)
            LR_image = np.array(LR_image).astype(np.uint8)

        else:
            LR_image = self.degradation_process(image=image)['image']

        example['image'] = (image / 127.5 - 1.0).astype(np.float32)
        example['LR_image'] = (LR_image / 127.5 - 1.0).astype(np.float32)

        return example


class ImageNetSRTrain(ImageNetSR):
    """Super-resolution training set: the ImageNetTrain subset listed in data/imagenet_train_hr_indices.p."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open('data/imagenet_train_hr_indices.p', 'rb') as f:
            indices = pickle.load(f)
        dset = ImageNetTrain(
            process_images=False,
        )
        return Subset(dset, indices)


class ImageNetSRValidation(ImageNetSR):
    """Super-resolution validation set: the ImageNetValidation subset listed in data/imagenet_val_hr_indices.p."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_base(self):
        with open('data/imagenet_val_hr_indices.p', 'rb') as f:
            indices = pickle.load(f)
        dset = ImageNetValidation(
            process_images=False,
        )
        return Subset(dset, indices)

================================================
FILE: src/stablediffusion/ldm/data/lsun.py
================================================
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class LSUNBase(Dataset):
    """LSUN images listed in `txt_file` (paths relative to `data_root`), served as center crops in [-1, 1]."""

    def __init__(
        self,
        txt_file,
        data_root,
        size=None,
        interpolation='bicubic',
        flip_p=0.5,
    ):
        self.data_paths = txt_file
        self.data_root = data_root
        with open(self.data_paths, 'r') as f:
            self.image_paths = f.read().splitlines()
        self._length = len(self.image_paths)
        self.labels = {
            'relative_file_path_': [l for l in self.image_paths],
            'file_path_': [
                os.path.join(self.data_root, l) for l in self.image_paths
            ],
        }

        self.size = size
        # Resampling filter used when resizing to `size`.
        self.interpolation = {
'linear': PIL.Image.LINEAR, 'bilinear': PIL.Image.BILINEAR, 'bicubic': PIL.Image.BICUBIC, 'lanczos': PIL.Image.LANCZOS, }[interpolation] self.flip = transforms.RandomHorizontalFlip(p=flip_p) def __len__(self): return self._length def __getitem__(self, i): example = dict((k, self.labels[k][i]) for k in self.labels) image = Image.open(example['file_path_']) if not image.mode == 'RGB': image = image.convert('RGB') # default to score-sde preprocessing img = np.array(image).astype(np.uint8) crop = min(img.shape[0], img.shape[1]) h, w, = ( img.shape[0], img.shape[1], ) img = img[ (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2, ] image = Image.fromarray(img) if self.size is not None: image = image.resize( (self.size, self.size), resample=self.interpolation ) image = self.flip(image) image = np.array(image).astype(np.uint8) example['image'] = (image / 127.5 - 1.0).astype(np.float32) return example class LSUNChurchesTrain(LSUNBase): def __init__(self, **kwargs): super().__init__( txt_file='data/lsun/church_outdoor_train.txt', data_root='data/lsun/churches', **kwargs ) class LSUNChurchesValidation(LSUNBase): def __init__(self, flip_p=0.0, **kwargs): super().__init__( txt_file='data/lsun/church_outdoor_val.txt', data_root='data/lsun/churches', flip_p=flip_p, **kwargs ) class LSUNBedroomsTrain(LSUNBase): def __init__(self, **kwargs): super().__init__( txt_file='data/lsun/bedrooms_train.txt', data_root='data/lsun/bedrooms', **kwargs ) class LSUNBedroomsValidation(LSUNBase): def __init__(self, flip_p=0.0, **kwargs): super().__init__( txt_file='data/lsun/bedrooms_val.txt', data_root='data/lsun/bedrooms', flip_p=flip_p, **kwargs ) class LSUNCatsTrain(LSUNBase): def __init__(self, **kwargs): super().__init__( txt_file='data/lsun/cat_train.txt', data_root='data/lsun/cats', **kwargs ) class LSUNCatsValidation(LSUNBase): def __init__(self, flip_p=0.0, **kwargs): super().__init__( txt_file='data/lsun/cat_val.txt', data_root='data/lsun/cats', flip_p=flip_p, 
**kwargs ) ================================================ FILE: src/stablediffusion/ldm/data/personalized.py ================================================ import os import numpy as np import PIL from PIL import Image from torch.utils.data import Dataset from torchvision import transforms import random imagenet_templates_smallest = [ 'a photo of a {}', ] imagenet_templates_small = [ 'a photo of a {}', 'a rendering of a {}', 'a cropped photo of the {}', 'the photo of a {}', 'a photo of a clean {}', 'a photo of a dirty {}', 'a dark photo of the {}', 'a photo of my {}', 'a photo of the cool {}', 'a close-up photo of a {}', 'a bright photo of the {}', 'a cropped photo of a {}', 'a photo of the {}', 'a good photo of the {}', 'a photo of one {}', 'a close-up photo of the {}', 'a rendition of the {}', 'a photo of the clean {}', 'a rendition of a {}', 'a photo of a nice {}', 'a good photo of a {}', 'a photo of the nice {}', 'a photo of the small {}', 'a photo of the weird {}', 'a photo of the large {}', 'a photo of a cool {}', 'a photo of a small {}', ] imagenet_dual_templates_small = [ 'a photo of a {} with {}', 'a rendering of a {} with {}', 'a cropped photo of the {} with {}', 'the photo of a {} with {}', 'a photo of a clean {} with {}', 'a photo of a dirty {} with {}', 'a dark photo of the {} with {}', 'a photo of my {} with {}', 'a photo of the cool {} with {}', 'a close-up photo of a {} with {}', 'a bright photo of the {} with {}', 'a cropped photo of a {} with {}', 'a photo of the {} with {}', 'a good photo of the {} with {}', 'a photo of one {} with {}', 'a close-up photo of the {} with {}', 'a rendition of the {} with {}', 'a photo of the clean {} with {}', 'a rendition of a {} with {}', 'a photo of a nice {} with {}', 'a good photo of a {} with {}', 'a photo of the nice {} with {}', 'a photo of the small {} with {}', 'a photo of the weird {} with {}', 'a photo of the large {} with {}', 'a photo of a cool {} with {}', 'a photo of a small {} with {}', ] 
per_img_token_list = [ 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', ] class PersonalizedBase(Dataset): def __init__( self, data_root, size=None, repeats=100, interpolation='bicubic', flip_p=0.5, set='train', placeholder_token='*', per_image_tokens=False, center_crop=False, mixing_prob=0.25, coarse_class_text=None, ): self.data_root = data_root self.image_paths = [ os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root) ] # self._length = len(self.image_paths) self.num_images = len(self.image_paths) self._length = self.num_images self.placeholder_token = placeholder_token self.per_image_tokens = per_image_tokens self.center_crop = center_crop self.mixing_prob = mixing_prob self.coarse_class_text = coarse_class_text if per_image_tokens: assert self.num_images < len( per_img_token_list ), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." 
if set == 'train': self._length = self.num_images * repeats self.size = size self.interpolation = { 'linear': PIL.Image.LINEAR, 'bilinear': PIL.Image.BILINEAR, 'bicubic': PIL.Image.BICUBIC, 'lanczos': PIL.Image.LANCZOS, }[interpolation] self.flip = transforms.RandomHorizontalFlip(p=flip_p) def __len__(self): return self._length def __getitem__(self, i): example = {} image = Image.open(self.image_paths[i % self.num_images]) if not image.mode == 'RGB': image = image.convert('RGB') placeholder_string = self.placeholder_token if self.coarse_class_text: placeholder_string = ( f'{self.coarse_class_text} {placeholder_string}' ) if self.per_image_tokens and np.random.uniform() < self.mixing_prob: text = random.choice(imagenet_dual_templates_small).format( placeholder_string, per_img_token_list[i % self.num_images] ) else: text = random.choice(imagenet_templates_small).format( placeholder_string ) example['caption'] = text # default to score-sde preprocessing img = np.array(image).astype(np.uint8) if self.center_crop: crop = min(img.shape[0], img.shape[1]) h, w, = ( img.shape[0], img.shape[1], ) img = img[ (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2, ] image = Image.fromarray(img) if self.size is not None: image = image.resize( (self.size, self.size), resample=self.interpolation ) image = self.flip(image) image = np.array(image).astype(np.uint8) example['image'] = (image / 127.5 - 1.0).astype(np.float32) return example ================================================ FILE: src/stablediffusion/ldm/data/personalized_style.py ================================================ import os import numpy as np import PIL from PIL import Image from torch.utils.data import Dataset from torchvision import transforms import random imagenet_templates_small = [ 'a painting in the style of {}', 'a rendering in the style of {}', 'a cropped painting in the style of {}', 'the painting in the style of {}', 'a clean painting in the style of {}', 'a dirty painting in the 
style of {}', 'a dark painting in the style of {}', 'a picture in the style of {}', 'a cool painting in the style of {}', 'a close-up painting in the style of {}', 'a bright painting in the style of {}', 'a cropped painting in the style of {}', 'a good painting in the style of {}', 'a close-up painting in the style of {}', 'a rendition in the style of {}', 'a nice painting in the style of {}', 'a small painting in the style of {}', 'a weird painting in the style of {}', 'a large painting in the style of {}', ] imagenet_dual_templates_small = [ 'a painting in the style of {} with {}', 'a rendering in the style of {} with {}', 'a cropped painting in the style of {} with {}', 'the painting in the style of {} with {}', 'a clean painting in the style of {} with {}', 'a dirty painting in the style of {} with {}', 'a dark painting in the style of {} with {}', 'a cool painting in the style of {} with {}', 'a close-up painting in the style of {} with {}', 'a bright painting in the style of {} with {}', 'a cropped painting in the style of {} with {}', 'a good painting in the style of {} with {}', 'a painting of one {} in the style of {}', 'a nice painting in the style of {} with {}', 'a small painting in the style of {} with {}', 'a weird painting in the style of {} with {}', 'a large painting in the style of {} with {}', ] per_img_token_list = [ 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', ] class PersonalizedBase(Dataset): def __init__( self, data_root, size=None, repeats=100, interpolation='bicubic', flip_p=0.5, set='train', placeholder_token='*', per_image_tokens=False, center_crop=False, ): self.data_root = data_root self.image_paths = [ os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root) ] # self._length = len(self.image_paths) self.num_images = len(self.image_paths) self._length = self.num_images self.placeholder_token = placeholder_token self.per_image_tokens = 
per_image_tokens self.center_crop = center_crop if per_image_tokens: assert self.num_images < len( per_img_token_list ), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." if set == 'train': self._length = self.num_images * repeats self.size = size self.interpolation = { 'linear': PIL.Image.LINEAR, 'bilinear': PIL.Image.BILINEAR, 'bicubic': PIL.Image.BICUBIC, 'lanczos': PIL.Image.LANCZOS, }[interpolation] self.flip = transforms.RandomHorizontalFlip(p=flip_p) def __len__(self): return self._length def __getitem__(self, i): example = {} image = Image.open(self.image_paths[i % self.num_images]) if not image.mode == 'RGB': image = image.convert('RGB') if self.per_image_tokens and np.random.uniform() < 0.25: text = random.choice(imagenet_dual_templates_small).format( self.placeholder_token, per_img_token_list[i % self.num_images] ) else: text = random.choice(imagenet_templates_small).format( self.placeholder_token ) example['caption'] = text # default to score-sde preprocessing img = np.array(image).astype(np.uint8) if self.center_crop: crop = min(img.shape[0], img.shape[1]) h, w, = ( img.shape[0], img.shape[1], ) img = img[ (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2, ] image = Image.fromarray(img) if self.size is not None: image = image.resize( (self.size, self.size), resample=self.interpolation ) image = self.flip(image) image = np.array(image).astype(np.uint8) example['image'] = (image / 127.5 - 1.0).astype(np.float32) return example ================================================ FILE: src/stablediffusion/ldm/dream/conditioning.py ================================================ ''' This module handles the generation of the conditioning tensors, including management of weighted subprompts. 
Useful function exports: get_uc_and_c() get the conditioned and unconditioned latent split_weighted_subpromopts() split subprompts, normalize and weight them log_tokenization() print out colour-coded tokens and warn if truncated ''' import re import torch def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False): uc = model.get_learned_conditioning(['']) # get weighted sub-prompts weighted_subprompts = split_weighted_subprompts( prompt, skip_normalize ) if len(weighted_subprompts) > 1: # i dont know if this is correct.. but it works c = torch.zeros_like(uc) # normalize each "sub prompt" and add it for subprompt, weight in weighted_subprompts: log_tokenization(subprompt, model, log_tokens) c = torch.add( c, model.get_learned_conditioning([subprompt]), alpha=weight, ) else: # just standard 1 prompt log_tokenization(prompt, model, log_tokens) c = model.get_learned_conditioning([prompt]) return (uc, c) def split_weighted_subprompts(text, skip_normalize=False)->list: """ grabs all text up to the first occurrence of ':' uses the grabbed text as a sub-prompt, and takes the value following ':' as weight if ':' has no value defined, defaults to 1.0 repeats until no text remaining """ prompt_parser = re.compile(""" (?P # capture group for 'prompt' (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' ) # end 'prompt' (?: # non-capture group :+ # match one or more ':' characters (?P # capture group for 'weight' -?\d+(?:\.\d+)? # match positive or negative integer or decimal number )? 
# end weight capture group, make optional \s* # strip spaces after weight | # OR $ # else, if no ':' then match end of line ) # end non-capture group """, re.VERBOSE) parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float( match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)] if skip_normalize: return parsed_prompts weight_sum = sum(map(lambda x: x[1], parsed_prompts)) if weight_sum == 0: print( "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.") equal_weight = 1 / len(parsed_prompts) return [(x[0], equal_weight) for x in parsed_prompts] return [(x[0], x[1] / weight_sum) for x in parsed_prompts] # shows how the prompt is tokenized # usually tokens have '' to indicate end-of-word, # but for readability it has been replaced with ' ' def log_tokenization(text, model, log=False): if not log: return tokens = model.cond_stage_model.tokenizer._tokenize(text) tokenized = "" discarded = "" usedTokens = 0 totalTokens = len(tokens) for i in range(0, totalTokens): token = tokens[i].replace('', ' ') # alternate color s = (usedTokens % 6) + 1 if i < model.cond_stage_model.max_length: tokenized = tokenized + f"\x1b[0;3{s};40m{token}" usedTokens += 1 else: # over max token length discarded = discarded + f"\x1b[0;3{s};40m{token}" print(f"\n>> Tokens ({usedTokens}):\n{tokenized}\x1b[0m") if discarded != "": print( f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m" ) ================================================ FILE: src/stablediffusion/ldm/dream/devices.py ================================================ import torch from torch import autocast from contextlib import contextmanager, nullcontext def choose_torch_device() -> str: '''Convenience routine for guessing which GPU device to run model on''' if torch.cuda.is_available(): return 'cuda' if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): return 'mps' return 'cpu' def choose_autocast_device(device): '''Returns an 
autocast compatible device from a torch device''' device_type = device.type # this returns 'mps' on M1 # autocast only supports cuda or cpu if device_type in ('cuda','cpu'): return device_type,autocast else: return 'cpu',nullcontext ================================================ FILE: src/stablediffusion/ldm/dream/generator/__init__.py ================================================ ''' Initialization file for the ldm.dream.generator package ''' from .base import Generator ================================================ FILE: src/stablediffusion/ldm/dream/generator/base.py ================================================ ''' Base class for ldm.dream.generator.* including img2img, txt2img, and inpaint ''' import torch import numpy as np import random from tqdm import tqdm, trange from PIL import Image from einops import rearrange, repeat from pytorch_lightning import seed_everything from src.stablediffusion.ldm.dream.devices import choose_autocast_device downsampling = 8 class Generator(): def __init__(self,model): self.model = model self.seed = None self.latent_channels = model.channels self.downsampling_factor = downsampling # BUG: should come from model or config self.variation_amount = 0 self.with_variations = [] # this is going to be overridden in img2img.py, txt2img.py and inpaint.py def get_make_image(self,prompt,**kwargs): """ Returns a function returning an image derived from the prompt and the initial image Return value depends on the seed at the time you call it """ raise NotImplementedError("image_iterator() must be implemented in a descendent class") def set_variation(self, seed, variation_amount, with_variations): self.seed = seed self.variation_amount = variation_amount self.with_variations = with_variations def generate(self,prompt,init_image,width,height,iterations=1,seed=None, image_callback=None, step_callback=None, **kwargs): device_type,scope = choose_autocast_device(self.model.device) make_image = self.get_make_image( prompt, init_image = 
init_image, width = width, height = height, step_callback = step_callback, **kwargs ) results = [] seed = seed if seed else self.new_seed() seed, initial_noise = self.generate_initial_noise(seed, width, height) with scope(device_type), self.model.ema_scope(): for n in trange(iterations, desc='Generating'): x_T = None if self.variation_amount > 0: seed_everything(seed) target_noise = self.get_noise(width,height) x_T = self.slerp(self.variation_amount, initial_noise, target_noise) elif initial_noise is not None: # i.e. we specified particular variations x_T = initial_noise else: seed_everything(seed) if self.model.device.type == 'mps': x_T = self.get_noise(width,height) # make_image will do the equivalent of get_noise itself image = make_image(x_T) results.append([image, seed]) if image_callback is not None: image_callback(image, seed) seed = self.new_seed() return results def sample_to_image(self,samples): """ Returns a function returning an image derived from the prompt and the initial image Return value depends on the seed at the time you call it """ x_samples = self.model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) if len(x_samples) != 1: raise Exception( f'>> expected to get a single image, but got {len(x_samples)}') x_sample = 255.0 * rearrange( x_samples[0].cpu().numpy(), 'c h w -> h w c' ) return Image.fromarray(x_sample.astype(np.uint8)) def generate_initial_noise(self, seed, width, height): initial_noise = None if self.variation_amount > 0 or len(self.with_variations) > 0: # use fixed initial noise plus random noise per iteration seed_everything(seed) initial_noise = self.get_noise(width,height) for v_seed, v_weight in self.with_variations: seed = v_seed seed_everything(seed) next_noise = self.get_noise(width,height) initial_noise = self.slerp(v_weight, initial_noise, next_noise) if self.variation_amount > 0: random.seed() # reset RNG to an actually random state, so we can get a random seed for variations 
seed = random.randrange(0,np.iinfo(np.uint32).max) return (seed, initial_noise) else: return (seed, None) # returns a tensor filled with random numbers from a normal distribution def get_noise(self,width,height): """ Returns a tensor filled with random numbers, either form a normal distribution (txt2img) or from the latent image (img2img, inpaint) """ raise NotImplementedError("get_noise() must be implemented in a descendent class") def new_seed(self): self.seed = random.randrange(0, np.iinfo(np.uint32).max) return self.seed def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995): ''' Spherical linear interpolation Args: t (float/np.ndarray): Float value between 0.0 and 1.0 v0 (np.ndarray): Starting vector v1 (np.ndarray): Final vector DOT_THRESHOLD (float): Threshold for considering the two vectors as colineal. Not recommended to alter this. Returns: v2 (np.ndarray): Interpolation vector between v0 and v1 ''' inputs_are_torch = False if not isinstance(v0, np.ndarray): inputs_are_torch = True v0 = v0.detach().cpu().numpy() if not isinstance(v1, np.ndarray): inputs_are_torch = True v1 = v1.detach().cpu().numpy() dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) if np.abs(dot) > DOT_THRESHOLD: v2 = (1 - t) * v0 + t * v1 else: theta_0 = np.arccos(dot) sin_theta_0 = np.sin(theta_0) theta_t = theta_0 * t sin_theta_t = np.sin(theta_t) s0 = np.sin(theta_0 - theta_t) / sin_theta_0 s1 = sin_theta_t / sin_theta_0 v2 = s0 * v0 + s1 * v1 if inputs_are_torch: v2 = torch.from_numpy(v2).to(self.model.device) return v2 ================================================ FILE: src/stablediffusion/ldm/dream/generator/img2img.py ================================================ ''' ldm.dream.generator.txt2img descends from src.stablediffusion.ldm.dream.generator ''' import torch import numpy as np from src.stablediffusion.ldm.dream.devices import choose_autocast_device from src.stablediffusion.ldm.dream.generator.base import Generator from 
src.stablediffusion.ldm.models.diffusion.ddim import DDIMSampler


class Img2Img(Generator):
    """Generator that noises an encoded init image and denoises it with DDIM."""

    def __init__(self,model):
        super().__init__(model)
        self.init_latent = None    # written by get_make_image(); get_noise() needs it set

    @torch.no_grad()
    def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta,
                       conditioning,init_image,strength,step_callback=None,**kwargs):
        """
        Returns a function returning an image derived from the prompt and the initial image
        Return value depends on the seed at the time you call it.

        Side effect: encodes init_image into self.init_latent, so this must run
        before get_noise(). t_enc = int(strength * steps) is the number of DDIM
        steps the init latent is pushed toward noise and then decoded back
        (strength presumably in [0, 1] -- TODO confirm with callers).
        conditioning is the (uc, c) pair; extra kwargs are ignored.
        """
        # PLMS sampler not supported yet, so ignore previous sampler
        # (any non-DDIM sampler is swapped for a fresh DDIMSampler)
        if not isinstance(sampler,DDIMSampler):
            print(
                f">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler"
            )
            sampler = DDIMSampler(self.model, device=self.model.device)

        # (re)build the DDIM schedule for this request's step count
        sampler.make_schedule(
            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
        )

        device_type,scope = choose_autocast_device(self.model.device)
        with scope(device_type):
            self.init_latent = self.model.get_first_stage_encoding(
                self.model.encode_first_stage(init_image)
            ) # move to latent space

        t_enc = int(strength * steps)
        uc, c = conditioning

        @torch.no_grad()
        def make_image(x_T):
            # encode (scaled latent)
            # x_T (from get_noise / slerp) is used as the forward-noise sample
            z_enc = sampler.stochastic_encode(
                self.init_latent,
                torch.tensor([t_enc]).to(self.model.device),
                noise=x_T
            )
            # decode it
            samples = sampler.decode(
                z_enc,
                c,
                t_enc,
                img_callback = step_callback,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uc,
            )
            return self.sample_to_image(samples)

        return make_image

    def get_noise(self,width,height):
        # width/height are unused: the noise takes init_latent's shape.
        device = self.model.device
        init_latent = self.init_latent
        assert init_latent is not None,'call to get_noise() when init_latent not set'
        # on mps, draw on cpu then move (as the base generator does for mps)
        if device.type == 'mps':
            return torch.randn_like(init_latent, device='cpu').to(device)
        else:
            return torch.randn_like(init_latent, device=device)
================================================ FILE: src/stablediffusion/ldm/dream/generator/inpaint.py ================================================
'''
ldm.dream.generator.inpaint descends from
src.stablediffusion.ldm.dream.generator ''' import torch import numpy as np from einops import rearrange, repeat from src.stablediffusion.ldm.dream.devices import choose_autocast_device from src.stablediffusion.ldm.dream.generator.img2img import Img2Img from src.stablediffusion.ldm.models.diffusion.ddim import DDIMSampler class Inpaint(Img2Img): def __init__(self,model): self.init_latent = None super().__init__(model) @torch.no_grad() def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, conditioning,init_image,mask_image,strength, step_callback=None,**kwargs): """ Returns a function returning an image derived from the prompt and the initial image + mask. Return value depends on the seed at the time you call it. kwargs are 'init_latent' and 'strength' """ mask_image = mask_image[0][0].unsqueeze(0).repeat(4,1,1).unsqueeze(0) mask_image = repeat(mask_image, '1 ... -> b ...', b=1) # PLMS sampler not supported yet, so ignore previous sampler if not isinstance(sampler,DDIMSampler): print( f">> sampler '{sampler.__class__.__name__}' is not yet supported. 
Using DDIM sampler" ) sampler = DDIMSampler(self.model, device=self.model.device) sampler.make_schedule( ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False ) device_type,scope = choose_autocast_device(self.model.device) with scope(device_type): self.init_latent = self.model.get_first_stage_encoding( self.model.encode_first_stage(init_image) ) # move to latent space t_enc = int(strength * steps) uc, c = conditioning print(f">> target t_enc is {t_enc} steps") @torch.no_grad() def make_image(x_T): # encode (scaled latent) z_enc = sampler.stochastic_encode( self.init_latent, torch.tensor([t_enc]).to(self.model.device), noise=x_T ) # decode it samples = sampler.decode( z_enc, c, t_enc, img_callback = step_callback, unconditional_guidance_scale = cfg_scale, unconditional_conditioning = uc, mask = mask_image, init_latent = self.init_latent ) return self.sample_to_image(samples) return make_image ================================================ FILE: src/stablediffusion/ldm/dream/generator/txt2img.py ================================================ ''' ldm.dream.generator.txt2img inherits from src.stablediffusion.ldm.dream.generator ''' import torch import numpy as np from src.stablediffusion.ldm.dream.generator.base import Generator class Txt2Img(Generator): def __init__(self,model): super().__init__(model) @torch.no_grad() def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, conditioning,width,height,step_callback=None,**kwargs): """ Returns a function returning an image derived from the prompt and the initial image Return value depends on the seed at the time you call it kwargs are 'width' and 'height' """ uc, c = conditioning @torch.no_grad() def make_image(x_T): shape = [ self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor, ] samples, _ = sampler.sample( batch_size = 1, S = steps, x_T = x_T, conditioning = c, shape = shape, verbose = False, unconditional_guidance_scale = cfg_scale, unconditional_conditioning = 
uc, eta = ddim_eta, img_callback = step_callback ) return self.sample_to_image(samples) return make_image # returns a tensor filled with random numbers from a normal distribution def get_noise(self,width,height): device = self.model.device if device.type == 'mps': return torch.randn([1, self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor], device='cpu').to(device) else: return torch.randn([1, self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor], device=device) ================================================ FILE: src/stablediffusion/ldm/dream/image_util.py ================================================ from math import sqrt, floor, ceil from PIL import Image class InitImageResizer(): """Simple class to create resized copies of an Image while preserving the aspect ratio.""" def __init__(self,Image): self.image = Image def resize(self,width=None,height=None) -> Image: """ Return a copy of the image resized to fit within a box width x height. The aspect ratio is maintained. If neither width nor height are provided, then returns a copy of the original image. If one or the other is provided, then the other will be calculated from the aspect ratio. 
Everything is floored to the nearest multiple of 64 so that it can be passed to img2img() """ im = self.image ar = im.width/float(im.height) # Infer missing values from aspect ratio if not(width or height): # both missing width = im.width height = im.height elif not height: # height missing height = int(width/ar) elif not width: # width missing width = int(height*ar) # rw and rh are the resizing width and height for the image # they maintain the aspect ratio, but may not completelyl fill up # the requested destination size (rw,rh) = (width,int(width/ar)) if im.width>=im.height else (int(height*ar),height) #round everything to multiples of 64 width,height,rw,rh = map( lambda x: x-x%64, (width,height,rw,rh) ) # no resize necessary, but return a copy if im.width == width and im.height == height: return im.copy() # otherwise resize the original image so that it fits inside the bounding box resized_image = self.image.resize((rw,rh),resample=Image.Resampling.LANCZOS) return resized_image def make_grid(image_list, rows=None, cols=None): image_cnt = len(image_list) if None in (rows, cols): rows = floor(sqrt(image_cnt)) # try to make it square cols = ceil(image_cnt / rows) width = image_list[0].width height = image_list[0].height grid_img = Image.new('RGB', (width * cols, height * rows)) i = 0 for r in range(0, rows): for c in range(0, cols): if i >= len(image_list): break grid_img.paste(image_list[i], (c * width, r * height)) i = i + 1 return grid_img ================================================ FILE: src/stablediffusion/ldm/dream/pngwriter.py ================================================ """ Two helper classes for dealing with PNG images and their path names. PngWriter -- Converts Images generated by T2I into PNGs, finds appropriate names for them, and writes prompt metadata into the PNG. PromptFormatter -- Utility for converting a Namespace of prompt parameters back into a formatted prompt string with command-line switches. 
""" import os import re from PIL import PngImagePlugin # -------------------image generation utils----- class PngWriter: def __init__(self, outdir): self.outdir = outdir os.makedirs(outdir, exist_ok=True) # gives the next unique prefix in outdir def unique_prefix(self): # sort reverse alphabetically until we find max+1 dirlist = sorted(os.listdir(self.outdir), reverse=True) # find the first filename that matches our pattern or return 000000.0.png existing_name = next( (f for f in dirlist if re.match('^(\d+)\..*\.png', f)), '0000000.0.png', ) basecount = int(existing_name.split('.', 1)[0]) + 1 return f'{basecount:06}' # saves image named _image_ to outdir/name, writing metadata from prompt # returns full path of output def save_image_and_prompt_to_png(self, image, prompt, name): path = os.path.join(self.outdir, name) info = PngImagePlugin.PngInfo() info.add_text('Dream', prompt) image.save(path, 'PNG', pnginfo=info) return path class PromptFormatter: def __init__(self, t2i, opt): self.t2i = t2i self.opt = opt # note: the t2i object should provide all these values. 
# there should be no need to or against opt values def normalize_prompt(self): """Normalize the prompt and switches""" t2i = self.t2i opt = self.opt switches = list() switches.append(f'"{opt.prompt}"') switches.append(f'-s{opt.steps or t2i.steps}') switches.append(f'-W{opt.width or t2i.width}') switches.append(f'-H{opt.height or t2i.height}') switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}') switches.append(f'-A{opt.sampler_name or t2i.sampler_name}') # to do: put model name into the t2i object # switches.append(f'--model{t2i.model_name}') if opt.seamless or t2i.seamless: switches.append(f'--seamless') if opt.init_img: switches.append(f'-I{opt.init_img}') if opt.fit: switches.append(f'--fit') if opt.strength and opt.init_img is not None: switches.append(f'-f{opt.strength or t2i.strength}') if opt.gfpgan_strength: switches.append(f'-G{opt.gfpgan_strength}') if opt.upscale: switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}') if opt.variation_amount > 0: switches.append(f'-v{opt.variation_amount}') if opt.with_variations: formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in opt.with_variations) switches.append(f'-V{formatted_variations}') return ' '.join(switches) ================================================ FILE: src/stablediffusion/ldm/dream/readline.py ================================================ """ Readline helper functions for dream.py (linux and mac only). 
""" import os import re import atexit # ---------------readline utilities--------------------- try: import readline readline_available = True except: readline_available = False class Completer: def __init__(self, options): self.options = sorted(options) return def complete(self, text, state): buffer = readline.get_line_buffer() if text.startswith(('-I', '--init_img','-M','--init_mask')): return self._path_completions(text, state, ('.png','.jpg','.jpeg')) if buffer.strip().endswith('cd') or text.startswith(('.', '/')): return self._path_completions(text, state, ()) response = None if state == 0: # This is the first time for this text, so build a match list. if text: self.matches = [ s for s in self.options if s and s.startswith(text) ] else: self.matches = self.options[:] # Return the state'th item from the match list, # if we have that many. try: response = self.matches[state] except IndexError: response = None return response def _path_completions(self, text, state, extensions): # get the path so far # TODO: replace this mess with a regular expression match if text.startswith('-I'): path = text.replace('-I', '', 1).lstrip() elif text.startswith('--init_img='): path = text.replace('--init_img=', '', 1).lstrip() elif text.startswith('--init_mask='): path = text.replace('--init_mask=', '', 1).lstrip() elif text.startswith('-M'): path = text.replace('-M', '', 1).lstrip() else: path = text matches = list() path = os.path.expanduser(path) if len(path) == 0: matches.append(text + './') else: dir = os.path.dirname(path) dir_list = os.listdir(dir) for n in dir_list: if n.startswith('.') and len(n) > 1: continue full_path = os.path.join(dir, n) if full_path.startswith(path): if os.path.isdir(full_path): matches.append( os.path.join(os.path.dirname(text), n) + '/' ) elif n.endswith(extensions): matches.append(os.path.join(os.path.dirname(text), n)) try: response = matches[state] except IndexError: response = None return response if readline_available: readline.set_completer( 
Completer( [ '--steps','-s', '--seed','-S', '--iterations','-n', '--width','-W','--height','-H', '--cfg_scale','-C', '--grid','-g', '--individual','-i', '--init_img','-I', '--init_mask','-M', '--strength','-f', '--variants','-v', '--outdir','-o', '--sampler','-A','-m', '--embedding_path', '--device', '--grid','-g', '--gfpgan_strength','-G', '--upscale','-U', '-save_orig','--save_original', '--skip_normalize','-x', '--log_tokenization','t', ] ).complete ) readline.set_completer_delims(' ') readline.parse_and_bind('tab: complete') histfile = os.path.join(os.path.expanduser('~'), '.dream_history') try: readline.read_history_file(histfile) readline.set_history_length(1000) except FileNotFoundError: pass atexit.register(readline.write_history_file, histfile) ================================================ FILE: src/stablediffusion/ldm/dream/server.py ================================================ import argparse import json import base64 import mimetypes import os from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from src.stablediffusion.ldm.dream.pngwriter import PngWriter, PromptFormatter from threading import Event def build_opt(post_data, seed, gfpgan_model_exists): opt = argparse.Namespace() setattr(opt, 'prompt', post_data['prompt']) setattr(opt, 'init_img', post_data['initimg']) setattr(opt, 'strength', float(post_data['strength'])) setattr(opt, 'iterations', int(post_data['iterations'])) setattr(opt, 'steps', int(post_data['steps'])) setattr(opt, 'width', int(post_data['width'])) setattr(opt, 'height', int(post_data['height'])) setattr(opt, 'seamless', 'seamless' in post_data) setattr(opt, 'fit', 'fit' in post_data) setattr(opt, 'mask', 'mask' in post_data) setattr(opt, 'invert_mask', 'invert_mask' in post_data) setattr(opt, 'cfg_scale', float(post_data['cfg_scale'])) setattr(opt, 'sampler_name', post_data['sampler_name']) setattr(opt, 'gfpgan_strength', float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0) setattr(opt, 
'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None) setattr(opt, 'progress_images', 'progress_images' in post_data) setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed'])) setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0) setattr(opt, 'with_variations', []) broken = False if int(post_data['seed']) != -1 and post_data['with_variations'] != '': for part in post_data['with_variations'].split(','): seed_and_weight = part.split(':') if len(seed_and_weight) != 2: print(f'could not parse with_variation part "{part}"') broken = True break try: seed = int(seed_and_weight[0]) weight = float(seed_and_weight[1]) except ValueError: print(f'could not parse with_variation part "{part}"') broken = True break opt.with_variations.append([seed, weight]) if broken: raise CanceledException if len(opt.with_variations) == 0: opt.with_variations = None return opt class CanceledException(Exception): pass class DreamServer(BaseHTTPRequestHandler): model = None outdir = None canceled = Event() def do_GET(self): if self.path == "/": self.send_response(200) self.send_header("Content-type", "text/html") self.end_headers() with open("./static/dream_web/index.html", "rb") as content: self.wfile.write(content.read()) elif self.path == "/config.js": # unfortunately this import can't be at the top level, since that would cause a circular import from src.stablediffusion.ldm.gfpgan.gfpgan_tools import gfpgan_model_exists self.send_response(200) self.send_header("Content-type", "application/javascript") self.end_headers() config = { 'gfpgan_model_exists': gfpgan_model_exists } self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8")) elif self.path == "/run_log.json": self.send_response(200) self.send_header("Content-type", "application/json") self.end_headers() output = [] log_file = os.path.join(self.outdir, 
"dream_web_log.txt") if os.path.exists(log_file): with open(log_file, "r") as log: for line in log: url, config = line.split(": {", maxsplit=1) config = json.loads("{" + config) config["url"] = url.lstrip(".") if os.path.exists(url): output.append(config) self.wfile.write(bytes(json.dumps({"run_log": output}), "utf-8")) elif self.path == "/cancel": self.canceled.set() self.send_response(200) self.send_header("Content-type", "application/json") self.end_headers() self.wfile.write(bytes('{}', 'utf8')) else: path = "." + self.path cwd = os.path.realpath(os.getcwd()) is_in_cwd = os.path.commonprefix((os.path.realpath(path), cwd)) == cwd if not (is_in_cwd and os.path.exists(path)): self.send_response(404) return mime_type = mimetypes.guess_type(path)[0] if mime_type is not None: self.send_response(200) self.send_header("Content-type", mime_type) self.end_headers() with open("." + self.path, "rb") as content: self.wfile.write(content.read()) else: self.send_response(404) def do_POST(self): self.send_response(200) self.send_header("Content-type", "application/json") self.end_headers() # unfortunately this import can't be at the top level, since that would cause a circular import from src.stablediffusion.ldm.gfpgan.gfpgan_tools import gfpgan_model_exists content_length = int(self.headers['Content-Length']) post_data = json.loads(self.rfile.read(content_length)) opt = build_opt(post_data, self.model.seed, gfpgan_model_exists) self.canceled.clear() print(f">> Request to generate with prompt: {opt.prompt}") # In order to handle upscaled images, the PngWriter needs to maintain state # across images generated by each call to prompt2img(), so we define it in # the outer scope of image_done() config = post_data.copy() # Shallow copy config['initimg'] = config.pop('initimg_name', '') images_generated = 0 # helps keep track of when upscaling is started images_upscaled = 0 # helps keep track of when upscaling is completed pngwriter = PngWriter(self.outdir) prefix = 
pngwriter.unique_prefix() # if upscaling is requested, then this will be called twice, once when # the images are first generated, and then again when after upscaling # is complete. The upscaling replaces the original file, so the second # entry should not be inserted into the image list. def image_done(image, seed, upscaled=False): name = f'{prefix}.{seed}.png' iter_opt = argparse.Namespace(**vars(opt)) # copy if opt.variation_amount > 0: this_variation = [[seed, opt.variation_amount]] if opt.with_variations is None: iter_opt.with_variations = this_variation else: iter_opt.with_variations = opt.with_variations + this_variation iter_opt.variation_amount = 0 elif opt.with_variations is None: iter_opt.seed = seed normalized_prompt = PromptFormatter(self.model, iter_opt).normalize_prompt() path = pngwriter.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{iter_opt.seed}', name) if int(config['seed']) == -1: config['seed'] = seed # Append post_data to log, but only once! if not upscaled: with open(os.path.join(self.outdir, "dream_web_log.txt"), "a") as log: log.write(f"{path}: {json.dumps(config)}\n") self.wfile.write(bytes(json.dumps( {'event': 'result', 'url': path, 'seed': seed, 'config': config} ) + '\n',"utf-8")) # control state of the "postprocessing..." message upscaling_requested = opt.upscale or opt.gfpgan_strength > 0 nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure. nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure. 
if upscaled: images_upscaled += 1 else: images_generated += 1 if upscaling_requested: action = None if images_generated >= opt.iterations: if images_upscaled < opt.iterations: action = 'upscaling-started' else: action = 'upscaling-done' if action: x = images_upscaled + 1 self.wfile.write(bytes(json.dumps( {'event': action, 'processed_file_cnt': f'{x}/{opt.iterations}'} ) + '\n',"utf-8")) step_writer = PngWriter(os.path.join(self.outdir, "intermediates")) step_index = 1 def image_progress(sample, step): if self.canceled.is_set(): self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8')) raise CanceledException path = None # since rendering images is moderately expensive, only render every 5th image # and don't bother with the last one, since it'll render anyway nonlocal step_index if opt.progress_images and step % 5 == 0 and step < opt.steps - 1: image = self.model.sample_to_image(sample) name = f'{prefix}.{opt.seed}.{step_index}.png' metadata = f'{opt.prompt} -S{opt.seed} [intermediate]' path = step_writer.save_image_and_prompt_to_png(image, metadata, name) step_index += 1 self.wfile.write(bytes(json.dumps( {'event': 'step', 'step': step + 1, 'url': path} ) + '\n',"utf-8")) try: if opt.init_img is None: # Run txt2img self.model.prompt2image(**vars(opt), step_callback=image_progress, image_callback=image_done) else: # Decode initimg as base64 to temp file with open("./img2img-tmp.png", "wb") as f: initimg = opt.init_img.split(",")[1] # Ignore mime type f.write(base64.b64decode(initimg)) opt1 = argparse.Namespace(**vars(opt)) opt1.init_img = "./img2img-tmp.png" try: # Run img2img self.model.prompt2image(**vars(opt1), step_callback=image_progress, image_callback=image_done) finally: # Remove the temp file os.remove("./img2img-tmp.png") except CanceledException: print(f"Canceled.") return class ThreadingDreamServer(ThreadingHTTPServer): def __init__(self, server_address): super(ThreadingDreamServer, self).__init__(server_address, DreamServer) 
================================================ FILE: src/stablediffusion/ldm/generate.py ================================================ # Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein) # Derived from source code carrying the following copyrights # Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors import torch import numpy as np import random import os import time import re import sys import traceback import transformers from omegaconf import OmegaConf from PIL import Image, ImageOps from torch import nn from pytorch_lightning import seed_everything from src.stablediffusion.ldm.util import instantiate_from_config from src.stablediffusion.ldm.models.diffusion.ddim import DDIMSampler from src.stablediffusion.ldm.models.diffusion.plms import PLMSSampler from src.stablediffusion.ldm.models.diffusion.ksampler import KSampler from src.stablediffusion.ldm.dream.pngwriter import PngWriter from src.stablediffusion.ldm.dream.image_util import InitImageResizer from src.stablediffusion.ldm.dream.devices import choose_torch_device from src.stablediffusion.ldm.dream.conditioning import get_uc_and_c """Simplified text to image API for stable diffusion/latent diffusion Example Usage: from src.stablediffusion.ldm.generate import Generate # Create an object with default values gr = Generate() # do the slow model initialization gr.load_model() # Do the fast inference & image generation. Any options passed here # override the default values assigned during class initialization # Will call load_model() if the model was not previously loaded and so # may be slow at first. # The method returns a list of images. Each row of the list is a sub-list of [filename,seed] results = gr.prompt2png(prompt = "an astronaut riding a horse", outdir = "./outputs/samples", iterations = 3) for row in results: print(f'filename={row[0]}') print(f'seed ={row[1]}') # Same thing, but using an initial image. 
results = gr.prompt2png(prompt = "an astronaut riding a horse", outdir = "./outputs/, iterations = 3, init_img = "./sketches/horse+rider.png") for row in results: print(f'filename={row[0]}') print(f'seed ={row[1]}') # Same thing, but we return a series of Image objects, which lets you manipulate them, # combine them, and save them under arbitrary names results = gr.prompt2image(prompt = "an astronaut riding a horse" outdir = "./outputs/") for row in results: im = row[0] seed = row[1] im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png') im.thumbnail(100,100).save('./outputs/samples/astronaut_thumb.jpg') Note that the old txt2img() and img2img() calls are deprecated but will still work. The full list of arguments to Generate() are: gr = Generate( weights = path to model weights ('models/ldm/stable-diffusion-v1/model.ckpt') config = path to model configuraiton ('configs/stable-diffusion/v1-inference.yaml') iterations = // how many times to run the sampling (1) steps = // 50 seed = // current system time sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms grid = // false width = // image width, multiple of 64 (512) height = // image height, multiple of 64 (512) cfg_scale = // condition-free guidance scale (7.5) ) """ class Generate: """Generate class Stores default values for multiple configuration items """ def __init__( self, iterations = 1, steps = 50, cfg_scale = 7.5, weights = 'models/ldm/stable-diffusion-v1/model.ckpt', config = 'configs/stable-diffusion/v1-inference.yaml', grid = False, width = 512, height = 512, sampler_name = 'k_lms', ddim_eta = 0.0, # deterministic precision = 'autocast', full_precision = False, strength = 0.75, # default in scripts/img2img.py seamless = False, embedding_path = None, device_type = 'cuda', ignore_ctrl_c = False, ): self.iterations = iterations self.width = width self.height = height self.steps = steps self.cfg_scale = cfg_scale self.weights = weights 
self.config = config self.sampler_name = sampler_name self.grid = grid self.ddim_eta = ddim_eta self.precision = precision self.full_precision = True if choose_torch_device() == 'mps' else full_precision self.strength = strength self.seamless = seamless self.embedding_path = embedding_path self.device_type = device_type self.ignore_ctrl_c = ignore_ctrl_c # note, this logic probably doesn't belong here... self.model = None # empty for now self.sampler = None self.device = None self.generators = {} self.base_generator = None self.seed = None if device_type == 'cuda' and not torch.cuda.is_available(): device_type = choose_torch_device() print(">> cuda not available, using device", device_type) self.device = torch.device(device_type) # for VRAM usage statistics device_type = choose_torch_device() self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None transformers.logging.set_verbosity_error() def prompt2png(self, prompt, outdir, **kwargs): """ Takes a prompt and an output directory, writes out the requested number of PNG files, and returns an array of [[filename,seed],[filename,seed]...] 
Optional named arguments are the same as those passed to Generate and prompt2image() """ results = self.prompt2image(prompt, **kwargs) pngwriter = PngWriter(outdir) prefix = pngwriter.unique_prefix() outputs = [] for image, seed in results: name = f'{prefix}.{seed}.png' path = pngwriter.save_image_and_prompt_to_png( image, f'{prompt} -S{seed}', name) outputs.append([path, seed]) return outputs def txt2img(self, prompt, **kwargs): outdir = kwargs.pop('outdir', 'outputs/img-samples') return self.prompt2png(prompt, outdir, **kwargs) def img2img(self, prompt, **kwargs): outdir = kwargs.pop('outdir', 'outputs/img-samples') assert ( 'init_img' in kwargs ), 'call to img2img() must include the init_img argument' return self.prompt2png(prompt, outdir, **kwargs) def prompt2image( self, # these are common prompt, iterations = None, steps = None, seed = None, cfg_scale = None, ddim_eta = None, skip_normalize = False, image_callback = None, step_callback = None, width = None, height = None, sampler_name = None, seamless = False, log_tokenization= False, with_variations = None, variation_amount = 0.0, # these are specific to img2img and inpaint init_img = None, init_mask = None, fit = False, strength = None, # these are specific to GFPGAN/ESRGAN gfpgan_strength= 0, save_original = False, upscale = None, **args, ): # eat up additional cruft """ ldm.generate.prompt2image() is the common entry point for txt2img() and img2img() It takes the following arguments: prompt // prompt string (no default) iterations // iterations (1); image count=iterations steps // refinement steps per iteration seed // seed for random number generator width // width of image, in multiples of 64 (512) height // height of image, in multiples of 64 (512) cfg_scale // how strongly the prompt influences the image (7.5) (must be >1) seamless // whether the generated image should tile init_img // path to an initial image strength // strength for noising/unnoising init_img. 
0.0 preserves image exactly, 1.0 replaces it completely gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image) step_callback // a function or method that will be called each step image_callback // a function or method that will be called each time an image is generated with_variations // a weighted list [(seed_1, weight_1), (seed_2, weight_2), ...] of variations which should be applied before doing any generation variation_amount // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image) To use the step callback, define a function that receives two arguments: - Image GPU data - The step number To use the image callback, define a function of method that receives two arguments, an Image object and the seed. You can then do whatever you like with the image, including converting it to different formats and manipulating it. For example: def process_image(image,seed): image.save(f{'images/seed.png'}) The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code to create the requested output directory, select a unique informative name for each image, and write the prompt into the PNG metadata. 
""" # TODO: convert this into a getattr() loop steps = steps or self.steps width = width or self.width height = height or self.height seamless = seamless or self.seamless cfg_scale = cfg_scale or self.cfg_scale ddim_eta = ddim_eta or self.ddim_eta iterations = iterations or self.iterations strength = strength or self.strength self.seed = seed self.log_tokenization = log_tokenization with_variations = [] if with_variations is None else with_variations model = ( self.load_model() ) # will instantiate the model or return it from cache for m in model.modules(): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): m.padding_mode = 'circular' if seamless else m._orig_padding_mode assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0' assert ( 0.0 < strength < 1.0 ), 'img2img and inpaint strength can only work with 0.0 < strength < 1.0' assert ( 0.0 <= variation_amount <= 1.0 ), '-v --variation_amount must be in [0.0, 1.0]' # check this logic - doesn't look right if len(with_variations) > 0 or variation_amount > 1.0: assert seed is not None,\ 'seed must be specified when using with_variations' if variation_amount == 0.0: assert iterations == 1,\ 'when using --with_variations, multiple iterations are only possible when using --variation_amount' assert all(0 <= weight <= 1 for _, weight in with_variations),\ f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}' width, height, _ = self._resolution_check(width, height, log=True) if sampler_name and (sampler_name != self.sampler_name): self.sampler_name = sampler_name self._set_sampler() tic = time.time() if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() results = list() init_image = None mask_image = None try: uc, c = get_uc_and_c( prompt, model=self.model, skip_normalize=skip_normalize, log_tokens=self.log_tokenization ) (init_image,mask_image) = self._make_images(init_img,init_mask, width, height, fit) if (init_image is not None) and (mask_image is not None): generator = 
self._make_inpaint() elif init_image is not None: generator = self._make_img2img() else: generator = self._make_txt2img() generator.set_variation(self.seed, variation_amount, with_variations) results = generator.generate( prompt, iterations = iterations, seed = self.seed, sampler = self.sampler, steps = steps, cfg_scale = cfg_scale, conditioning = (uc,c), ddim_eta = ddim_eta, image_callback = image_callback, # called after the final image is generated step_callback = step_callback, # called after each intermediate image is generated width = width, height = height, init_image = init_image, # notice that init_image is different from init_img mask_image = mask_image, strength = strength, ) if upscale is not None or gfpgan_strength > 0: self.upscale_and_reconstruct(results, upscale = upscale, strength = gfpgan_strength, save_original = save_original, image_callback = image_callback) except KeyboardInterrupt: print('*interrupted*') if not self.ignore_ctrl_c: raise KeyboardInterrupt print( '>> Partial results will be returned; if --grid was requested, nothing will be returned.' ) except RuntimeError as e: print(traceback.format_exc(), file=sys.stderr) print('>> Could not generate image.') toc = time.time() print('>> Usage stats:') print( f'>> {len(results)} image(s) generated in', '%4.2fs' % (toc - tic) ) if torch.cuda.is_available() and self.device.type == 'cuda': print( f'>> Max VRAM used for this generation:', '%4.2fG.' 
% (torch.cuda.max_memory_allocated() / 1e9), 'Current VRAM utilization:' '%4.2fG' % (torch.cuda.memory_allocated() / 1e9), ) self.session_peakmem = max( self.session_peakmem, torch.cuda.max_memory_allocated() ) print( f'>> Max VRAM used since script start: ', '%4.2fG' % (self.session_peakmem / 1e9), ) return results def _make_images(self, img_path, mask_path, width, height, fit=False): init_image = None init_mask = None if not img_path: return None,None image = self._load_img(img_path, width, height, fit=fit) # this returns an Image init_image = self._create_init_image(image) # this returns a torch tensor if self._has_transparency(image) and not mask_path: # if image has a transparent area and no mask was provided, then try to generate mask print('>> Initial image has transparent areas. Will inpaint in these regions.') if self._check_for_erasure(image): print( '>> WARNING: Colors underneath the transparent region seem to have been erased.\n', '>> Inpainting will be suboptimal. Please preserve the colors when making\n', '>> a transparency mask, or provide mask explicitly using --init_mask (-M).' 
) init_mask = self._create_init_mask(image) # this returns a torch tensor if mask_path: mask_image = self._load_img(mask_path, width, height, fit=fit) # this returns an Image init_mask = self._create_init_mask(mask_image) return init_image,init_mask def _make_img2img(self): if not self.generators.get('img2img'): from src.stablediffusion.ldm.dream.generator.img2img import Img2Img self.generators['img2img'] = Img2Img(self.model) return self.generators['img2img'] def _make_txt2img(self): if not self.generators.get('txt2img'): from src.stablediffusion.ldm.dream.generator.txt2img import Txt2Img self.generators['txt2img'] = Txt2Img(self.model) return self.generators['txt2img'] def _make_inpaint(self): if not self.generators.get('inpaint'): from src.stablediffusion.ldm.dream.generator.inpaint import Inpaint self.generators['inpaint'] = Inpaint(self.model) return self.generators['inpaint'] def load_model(self): """Load and initialize the model from configuration variables passed at object creation time""" if self.model is None: seed_everything(random.randrange(0, np.iinfo(np.uint32).max)) try: config = OmegaConf.load(self.config) model = self._load_model_from_config(config, self.weights) if self.embedding_path is not None: model.embedding_manager.load( self.embedding_path, self.full_precision ) self.model = model.to(self.device) # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here self.model.cond_stage_model.device = self.device except AttributeError as e: print(f'>> Error loading model. 
{str(e)}', file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) raise SystemExit from e self._set_sampler() for m in self.model.modules(): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): m._orig_padding_mode = m.padding_mode return self.model def upscale_and_reconstruct(self, image_list, upscale = None, strength = 0.0, save_original = False, image_callback = None): try: if upscale is not None: from src.stablediffusion.ldm.gfpgan.gfpgan_tools import real_esrgan_upscale if strength > 0: from src.stablediffusion.ldm.gfpgan.gfpgan_tools import run_gfpgan except (ModuleNotFoundError, ImportError): print(traceback.format_exc(), file=sys.stderr) print('>> You may need to install the ESRGAN and/or GFPGAN modules') return for r in image_list: image, seed = r try: if upscale is not None: if len(upscale) < 2: upscale.append(0.75) image = real_esrgan_upscale( image, upscale[1], int(upscale[0]), seed, ) if strength > 0: image = run_gfpgan( image, strength, seed, 1 ) except Exception as e: print( f'>> Error running RealESRGAN or GFPGAN. 
Your image was not upscaled.\n{e}' ) if image_callback is not None: image_callback(image, seed, upscaled=True) else: r[0] = image # to help WebGUI - front end to generator util function def sample_to_image(self,samples): return self._sample_to_image(samples) def _sample_to_image(self,samples): if not self.base_generator: from src.stablediffusion.ldm.dream.generator import Generator self.base_generator = Generator(self.model) return self.base_generator.sample_to_image(samples) def _set_sampler(self): msg = f'>> Setting Sampler to {self.sampler_name}' if self.sampler_name == 'plms': self.sampler = PLMSSampler(self.model, device=self.device) elif self.sampler_name == 'ddim': self.sampler = DDIMSampler(self.model, device=self.device) elif self.sampler_name == 'k_dpm_2_a': self.sampler = KSampler( self.model, 'dpm_2_ancestral', device=self.device ) elif self.sampler_name == 'k_dpm_2': self.sampler = KSampler(self.model, 'dpm_2', device=self.device) elif self.sampler_name == 'k_euler_a': self.sampler = KSampler( self.model, 'euler_ancestral', device=self.device ) elif self.sampler_name == 'k_euler': self.sampler = KSampler(self.model, 'euler', device=self.device) elif self.sampler_name == 'k_heun': self.sampler = KSampler(self.model, 'heun', device=self.device) elif self.sampler_name == 'k_lms': self.sampler = KSampler(self.model, 'lms', device=self.device) else: msg = f'>> Unsupported Sampler: {self.sampler_name}, Defaulting to plms' self.sampler = PLMSSampler(self.model, device=self.device) print(msg) def _load_model_from_config(self, config, ckpt): print(f'>> Loading model from {ckpt}') # for usage statistics device_type = choose_torch_device() if device_type == 'cuda': torch.cuda.reset_peak_memory_stats() tic = time.time() # this does the work pl_sd = torch.load(ckpt, map_location='cpu') sd = pl_sd['state_dict'] model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) if self.full_precision: print( '>> Using slower but more 
accurate full-precision math (--full_precision)' ) else: print( '>> Using half precision math. Call with --full_precision to use more accurate but VRAM-intensive full precision.' ) model.half() model.to(self.device) model.eval() # usage statistics toc = time.time() print( f'>> Model loaded in', '%4.2fs' % (toc - tic) ) if device_type == 'cuda': print( '>> Max VRAM used to load the model:', '%4.2fG' % (torch.cuda.max_memory_allocated() / 1e9), '\n>> Current VRAM usage:' '%4.2fG' % (torch.cuda.memory_allocated() / 1e9), ) return model def _load_img(self, path, width, height, fit=False): assert os.path.exists(path), f'>> {path}: File not found' # with Image.open(path) as img: # image = img.convert('RGBA') image = Image.open(path) print( f'>> loaded input image of size {image.width}x{image.height} from {path}' ) if fit: image = self._fit_image(image,(width,height)) else: image = self._squeeze_image(image) return image def _create_init_image(self,image): image = image.convert('RGB') # print( # f'>> DEBUG: writing the image to img.png' # ) # image.save('img.png') image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) image = 2.0 * image - 1.0 return image.to(self.device) def _create_init_mask(self, image): # convert into a black/white mask image = self._image_to_mask(image) image = image.convert('RGB') # BUG: We need to use the model's downsample factor rather than hardcoding "8" from src.stablediffusion.ldm.dream.generator.base import downsampling image = image.resize((image.width//downsampling, image.height//downsampling), resample=Image.Resampling.LANCZOS) # print( # f'>> DEBUG: writing the mask to mask.png' # ) # image.save('mask.png') image = np.array(image) image = image.astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) return image.to(self.device) # The mask is expected to have the region to be inpainted # with alpha transparency. 
It converts it into a black/white # image with the transparent part black. def _image_to_mask(self, mask_image, invert=False) -> Image: # Obtain the mask from the transparency channel mask = Image.new(mode="L", size=mask_image.size, color=255) mask.putdata(mask_image.getdata(band=3)) if invert: mask = ImageOps.invert(mask) return mask def _has_transparency(self,image): if image.info.get("transparency", None) is not None: return True if image.mode == "P": transparent = image.info.get("transparency", -1) for _, index in image.getcolors(): if index == transparent: return True elif image.mode == "RGBA": extrema = image.getextrema() if extrema[3][0] < 255: return True return False def _check_for_erasure(self,image): width, height = image.size pixdata = image.load() colored = 0 for y in range(height): for x in range(width): if pixdata[x, y][3] == 0: r, g, b, _ = pixdata[x, y] if (r, g, b) != (0, 0, 0) and \ (r, g, b) != (255, 255, 255): colored += 1 return colored == 0 def _squeeze_image(self,image): x,y,resize_needed = self._resolution_check(image.width,image.height) if resize_needed: return InitImageResizer(image).resize(x,y) return image def _fit_image(self,image,max_dimensions): w,h = max_dimensions print( f'>> image will be resized to fit inside a box {w}x{h} in size.' 
) if image.width > image.height: h = None # by setting h to none, we tell InitImageResizer to fit into the width and calculate height elif image.height > image.width: w = None # ditto for w else: pass image = InitImageResizer(image).resize(w,h) # note that InitImageResizer does the multiple of 64 truncation internally print( f'>> after adjusting image dimensions to be multiples of 64, init image is {image.width}x{image.height}' ) return image def _resolution_check(self, width, height, log=False): resize_needed = False w, h = map( lambda x: x - x % 64, (width, height) ) # resize to integer multiple of 64 if h != height or w != width: if log: print( f'>> Provided width and height must be multiples of 64. Auto-resizing to {w}x{h}' ) height = h width = w resize_needed = True if (width * height) > (self.width * self.height): print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.") return width, height, resize_needed ================================================ FILE: src/stablediffusion/ldm/gfpgan/gfpgan_tools.py ================================================ import torch import warnings import os import sys import numpy as np from PIL import Image from scripts.dream import create_argv_parser arg_parser = create_argv_parser() opt = arg_parser.parse_args() model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path) gfpgan_model_exists = os.path.isfile(model_path) def run_gfpgan(image, strength, seed, upsampler_scale=4): print(f'>> GFPGAN - Restoring Faces for image seed:{seed}') gfpgan = None with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings('ignore', category=UserWarning) try: if not gfpgan_model_exists: raise Exception('GFPGAN model not found at path ' + model_path) sys.path.append(os.path.abspath(opt.gfpgan_dir)) from gfpgan import GFPGANer bg_upsampler = _load_gfpgan_bg_upsampler( opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile 
) gfpgan = GFPGANer( model_path=model_path, upscale=upsampler_scale, arch='clean', channel_multiplier=2, bg_upsampler=bg_upsampler, ) except Exception: import traceback print('>> Error loading GFPGAN:', file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) if gfpgan is None: print( f'>> WARNING: GFPGAN not initialized.' ) print( f'>> Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth to {model_path}, \nor change GFPGAN directory with --gfpgan_dir.' ) return image image = image.convert('RGB') cropped_faces, restored_faces, restored_img = gfpgan.enhance( np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True, ) res = Image.fromarray(restored_img) if strength < 1.0: # Resize the image to the new image if the sizes have changed if restored_img.size != image.size: image = image.resize(res.size) res = Image.blend(image, res, strength) if torch.cuda.is_available(): torch.cuda.empty_cache() gfpgan = None return res def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400): if bg_upsampler == 'realesrgan': if not torch.cuda.is_available(): # CPU or MPS on M1 use_half_precision = False else: use_half_precision = True model_path = { 2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', 4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', } if upsampler_scale not in model_path: return None from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer if upsampler_scale == 4: model = RRDBNet( num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4, ) if upsampler_scale == 2: model = RRDBNet( num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2, ) bg_upsampler = RealESRGANer( scale=upsampler_scale, model_path=model_path[upsampler_scale], model=model, tile=bg_tile, tile_pad=10, pre_pad=0, half=use_half_precision, ) else: bg_upsampler = None 
return bg_upsampler def real_esrgan_upscale(image, strength, upsampler_scale, seed): print( f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x' ) with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings('ignore', category=UserWarning) try: upsampler = _load_gfpgan_bg_upsampler( opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile ) except Exception: import traceback print('>> Error loading Real-ESRGAN:', file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) output, img_mode = upsampler.enhance( np.array(image, dtype=np.uint8), outscale=upsampler_scale, alpha_upsampler=opt.gfpgan_bg_upsampler, ) res = Image.fromarray(output) if strength < 1.0: # Resize the image to the new image if the sizes have changed if output.size != image.size: image = image.resize(res.size) res = Image.blend(image, res, strength) if torch.cuda.is_available(): torch.cuda.empty_cache() upsampler = None return res ================================================ FILE: src/stablediffusion/ldm/lr_scheduler.py ================================================ import numpy as np class LambdaWarmUpCosineScheduler: """ note: use with a base_lr of 1.0 """ def __init__( self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0, ): self.lr_warm_up_steps = warm_up_steps self.lr_start = lr_start self.lr_min = lr_min self.lr_max = lr_max self.lr_max_decay_steps = max_decay_steps self.last_lr = 0.0 self.verbosity_interval = verbosity_interval def schedule(self, n, **kwargs): if self.verbosity_interval > 0: if n % self.verbosity_interval == 0: print( f'current step: {n}, recent lr-multiplier: {self.last_lr}' ) if n < self.lr_warm_up_steps: lr = ( self.lr_max - self.lr_start ) / self.lr_warm_up_steps * n + self.lr_start self.last_lr = lr return lr else: t = (n - self.lr_warm_up_steps) / ( self.lr_max_decay_steps - self.lr_warm_up_steps ) t = min(t, 1.0) lr = self.lr_min + 0.5 * (self.lr_max 
- self.lr_min) * ( 1 + np.cos(t * np.pi) ) self.last_lr = lr return lr def __call__(self, n, **kwargs): return self.schedule(n, **kwargs) class LambdaWarmUpCosineScheduler2: """ supports repeated iterations, configurable via lists note: use with a base_lr of 1.0. """ def __init__( self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0, ): assert ( len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) ) self.lr_warm_up_steps = warm_up_steps self.f_start = f_start self.f_min = f_min self.f_max = f_max self.cycle_lengths = cycle_lengths self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) self.last_f = 0.0 self.verbosity_interval = verbosity_interval def find_in_interval(self, n): interval = 0 for cl in self.cum_cycles[1:]: if n <= cl: return interval interval += 1 def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: if n % self.verbosity_interval == 0: print( f'current step: {n}, recent lr-multiplier: {self.last_f}, ' f'current cycle {cycle}' ) if n < self.lr_warm_up_steps[cycle]: f = ( self.f_max[cycle] - self.f_start[cycle] ) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f return f else: t = (n - self.lr_warm_up_steps[cycle]) / ( self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] ) t = min(t, 1.0) f = self.f_min[cycle] + 0.5 * ( self.f_max[cycle] - self.f_min[cycle] ) * (1 + np.cos(t * np.pi)) self.last_f = f return f def __call__(self, n, **kwargs): return self.schedule(n, **kwargs) class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: if n % self.verbosity_interval == 0: print( f'current step: {n}, recent lr-multiplier: {self.last_f}, ' f'current cycle {cycle}' ) if n < self.lr_warm_up_steps[cycle]: f = ( self.f_max[cycle] - self.f_start[cycle] ) / 
self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f return f else: f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( self.cycle_lengths[cycle] - n ) / (self.cycle_lengths[cycle]) self.last_f = f return f ================================================ FILE: src/stablediffusion/ldm/models/autoencoder.py ================================================ import torch import pytorch_lightning as pl import torch.nn.functional as F from contextlib import contextmanager from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer from src.stablediffusion.ldm.modules.diffusionmodules.model import Encoder, Decoder from src.stablediffusion.ldm.modules.distributions.distributions import ( DiagonalGaussianDistribution, ) from src.stablediffusion.ldm.util import instantiate_from_config class VQModel(pl.LightningModule): def __init__( self, ddconfig, lossconfig, n_embed, embed_dim, ckpt_path=None, ignore_keys=[], image_key='image', colorize_nlabels=None, monitor=None, batch_resize_range=None, scheduler_config=None, lr_g_factor=1.0, remap=None, sane_index_shape=False, # tell vector quantizer to return indices as bhw use_ema=False, ): super().__init__() self.embed_dim = embed_dim self.n_embed = n_embed self.image_key = image_key self.encoder = Encoder(**ddconfig) self.decoder = Decoder(**ddconfig) self.loss = instantiate_from_config(lossconfig) self.quantize = VectorQuantizer( n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape, ) self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d( embed_dim, ddconfig['z_channels'], 1 ) if colorize_nlabels is not None: assert type(colorize_nlabels) == int self.register_buffer( 'colorize', torch.randn(3, colorize_nlabels, 1, 1) ) if monitor is not None: self.monitor = monitor self.batch_resize_range = batch_resize_range if self.batch_resize_range is not None: print( f'{self.__class__.__name__}: Using per-batch 
resizing in range {batch_resize_range}.' ) self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self) print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.') if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) self.scheduler_config = scheduler_config self.lr_g_factor = lr_g_factor @contextmanager def ema_scope(self, context=None): if self.use_ema: self.model_ema.store(self.parameters()) self.model_ema.copy_to(self) if context is not None: print(f'{context}: Switched to EMA weights') try: yield None finally: if self.use_ema: self.model_ema.restore(self.parameters()) if context is not None: print(f'{context}: Restored training weights') def init_from_ckpt(self, path, ignore_keys=list()): sd = torch.load(path, map_location='cpu')['state_dict'] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): print('Deleting key {} from state_dict.'.format(k)) del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) print( f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys' ) if len(missing) > 0: print(f'Missing Keys: {missing}') print(f'Unexpected Keys: {unexpected}') def on_train_batch_end(self, *args, **kwargs): if self.use_ema: self.model_ema(self) def encode(self, x): h = self.encoder(x) h = self.quant_conv(h) quant, emb_loss, info = self.quantize(h) return quant, emb_loss, info def encode_to_prequant(self, x): h = self.encoder(x) h = self.quant_conv(h) return h def decode(self, quant): quant = self.post_quant_conv(quant) dec = self.decoder(quant) return dec def decode_code(self, code_b): quant_b = self.quantize.embed_code(code_b) dec = self.decode(quant_b) return dec def forward(self, input, return_pred_indices=False): quant, diff, (_, _, ind) = self.encode(input) dec = self.decode(quant) if return_pred_indices: return dec, diff, ind return dec, diff def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] x = ( 
x.permute(0, 3, 1, 2) .to(memory_format=torch.contiguous_format) .float() ) if self.batch_resize_range is not None: lower_size = self.batch_resize_range[0] upper_size = self.batch_resize_range[1] if self.global_step <= 4: # do the first few batches with max size to avoid later oom new_resize = upper_size else: new_resize = np.random.choice( np.arange(lower_size, upper_size + 16, 16) ) if new_resize != x.shape[2]: x = F.interpolate(x, size=new_resize, mode='bicubic') x = x.detach() return x def training_step(self, batch, batch_idx, optimizer_idx): # https://github.com/pytorch/pytorch/issues/37142 # try not to fool the heuristics x = self.get_input(batch, self.image_key) xrec, qloss, ind = self(x, return_pred_indices=True) if optimizer_idx == 0: # autoencode aeloss, log_dict_ae = self.loss( qloss, x, xrec, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split='train', predicted_indices=ind, ) self.log_dict( log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True, ) return aeloss if optimizer_idx == 1: # discriminator discloss, log_dict_disc = self.loss( qloss, x, xrec, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split='train', ) self.log_dict( log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True, ) return discloss def validation_step(self, batch, batch_idx): log_dict = self._validation_step(batch, batch_idx) with self.ema_scope(): log_dict_ema = self._validation_step( batch, batch_idx, suffix='_ema' ) return log_dict def _validation_step(self, batch, batch_idx, suffix=''): x = self.get_input(batch, self.image_key) xrec, qloss, ind = self(x, return_pred_indices=True) aeloss, log_dict_ae = self.loss( qloss, x, xrec, 0, self.global_step, last_layer=self.get_last_layer(), split='val' + suffix, predicted_indices=ind, ) discloss, log_dict_disc = self.loss( qloss, x, xrec, 1, self.global_step, last_layer=self.get_last_layer(), split='val' + suffix, predicted_indices=ind, ) rec_loss = 
log_dict_ae[f'val{suffix}/rec_loss'] self.log( f'val{suffix}/rec_loss', rec_loss, prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True, ) self.log( f'val{suffix}/aeloss', aeloss, prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True, ) if version.parse(pl.__version__) >= version.parse('1.4.0'): del log_dict_ae[f'val{suffix}/rec_loss'] self.log_dict(log_dict_ae) self.log_dict(log_dict_disc) return self.log_dict def configure_optimizers(self): lr_d = self.learning_rate lr_g = self.lr_g_factor * self.learning_rate print('lr_d', lr_d) print('lr_g', lr_g) opt_ae = torch.optim.Adam( list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(self.quantize.parameters()) + list(self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()), lr=lr_g, betas=(0.5, 0.9), ) opt_disc = torch.optim.Adam( self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9) ) if self.scheduler_config is not None: scheduler = instantiate_from_config(self.scheduler_config) print('Setting up LambdaLR scheduler...') scheduler = [ { 'scheduler': LambdaLR( opt_ae, lr_lambda=scheduler.schedule ), 'interval': 'step', 'frequency': 1, }, { 'scheduler': LambdaLR( opt_disc, lr_lambda=scheduler.schedule ), 'interval': 'step', 'frequency': 1, }, ] return [opt_ae, opt_disc], scheduler return [opt_ae, opt_disc], [] def get_last_layer(self): return self.decoder.conv_out.weight def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): log = dict() x = self.get_input(batch, self.image_key) x = x.to(self.device) if only_inputs: log['inputs'] = x return log xrec, _ = self(x) if x.shape[1] > 3: # colorize with random projection assert xrec.shape[1] > 3 x = self.to_rgb(x) xrec = self.to_rgb(xrec) log['inputs'] = x log['reconstructions'] = xrec if plot_ema: with self.ema_scope(): xrec_ema, _ = self(x) if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) log['reconstructions_ema'] = xrec_ema return log def to_rgb(self, x): assert 
self.image_key == 'segmentation' if not hasattr(self, 'colorize'): self.register_buffer( 'colorize', torch.randn(3, x.shape[1], 1, 1).to(x) ) x = F.conv2d(x, weight=self.colorize) x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x class VQModelInterface(VQModel): def __init__(self, embed_dim, *args, **kwargs): super().__init__(embed_dim=embed_dim, *args, **kwargs) self.embed_dim = embed_dim def encode(self, x): h = self.encoder(x) h = self.quant_conv(h) return h def decode(self, h, force_not_quantize=False): # also go through quantization layer if not force_not_quantize: quant, emb_loss, info = self.quantize(h) else: quant = h quant = self.post_quant_conv(quant) dec = self.decoder(quant) return dec class AutoencoderKL(pl.LightningModule): def __init__( self, ddconfig, lossconfig, embed_dim, ckpt_path=None, ignore_keys=[], image_key='image', colorize_nlabels=None, monitor=None, ): super().__init__() self.image_key = image_key self.encoder = Encoder(**ddconfig) self.decoder = Decoder(**ddconfig) self.loss = instantiate_from_config(lossconfig) assert ddconfig['double_z'] self.quant_conv = torch.nn.Conv2d( 2 * ddconfig['z_channels'], 2 * embed_dim, 1 ) self.post_quant_conv = torch.nn.Conv2d( embed_dim, ddconfig['z_channels'], 1 ) self.embed_dim = embed_dim if colorize_nlabels is not None: assert type(colorize_nlabels) == int self.register_buffer( 'colorize', torch.randn(3, colorize_nlabels, 1, 1) ) if monitor is not None: self.monitor = monitor if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) def init_from_ckpt(self, path, ignore_keys=list()): sd = torch.load(path, map_location='cpu')['state_dict'] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): print('Deleting key {} from state_dict.'.format(k)) del sd[k] self.load_state_dict(sd, strict=False) print(f'Restored from {path}') def encode(self, x): h = self.encoder(x) moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) 
return posterior def decode(self, z): z = self.post_quant_conv(z) dec = self.decoder(z) return dec def forward(self, input, sample_posterior=True): posterior = self.encode(input) if sample_posterior: z = posterior.sample() else: z = posterior.mode() dec = self.decode(z) return dec, posterior def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] x = ( x.permute(0, 3, 1, 2) .to(memory_format=torch.contiguous_format) .float() ) return x def training_step(self, batch, batch_idx, optimizer_idx): inputs = self.get_input(batch, self.image_key) reconstructions, posterior = self(inputs) if optimizer_idx == 0: # train encoder+decoder+logvar aeloss, log_dict_ae = self.loss( inputs, reconstructions, posterior, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split='train', ) self.log( 'aeloss', aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, ) self.log_dict( log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False, ) return aeloss if optimizer_idx == 1: # train the discriminator discloss, log_dict_disc = self.loss( inputs, reconstructions, posterior, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split='train', ) self.log( 'discloss', discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, ) self.log_dict( log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False, ) return discloss def validation_step(self, batch, batch_idx): inputs = self.get_input(batch, self.image_key) reconstructions, posterior = self(inputs) aeloss, log_dict_ae = self.loss( inputs, reconstructions, posterior, 0, self.global_step, last_layer=self.get_last_layer(), split='val', ) discloss, log_dict_disc = self.loss( inputs, reconstructions, posterior, 1, self.global_step, last_layer=self.get_last_layer(), split='val', ) self.log('val/rec_loss', log_dict_ae['val/rec_loss']) self.log_dict(log_dict_ae) self.log_dict(log_dict_disc) return self.log_dict def configure_optimizers(self): lr = 
self.learning_rate opt_ae = torch.optim.Adam( list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()), lr=lr, betas=(0.5, 0.9), ) opt_disc = torch.optim.Adam( self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) ) return [opt_ae, opt_disc], [] def get_last_layer(self): return self.decoder.conv_out.weight @torch.no_grad() def log_images(self, batch, only_inputs=False, **kwargs): log = dict() x = self.get_input(batch, self.image_key) x = x.to(self.device) if not only_inputs: xrec, posterior = self(x) if x.shape[1] > 3: # colorize with random projection assert xrec.shape[1] > 3 x = self.to_rgb(x) xrec = self.to_rgb(xrec) log['samples'] = self.decode(torch.randn_like(posterior.sample())) log['reconstructions'] = xrec log['inputs'] = x return log def to_rgb(self, x): assert self.image_key == 'segmentation' if not hasattr(self, 'colorize'): self.register_buffer( 'colorize', torch.randn(3, x.shape[1], 1, 1).to(x) ) x = F.conv2d(x, weight=self.colorize) x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x class IdentityFirstStage(torch.nn.Module): def __init__(self, *args, vq_interface=False, **kwargs): self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff super().__init__() def encode(self, x, *args, **kwargs): return x def decode(self, x, *args, **kwargs): return x def quantize(self, x, *args, **kwargs): if self.vq_interface: return x, None, [None, None, None] return x def forward(self, x, *args, **kwargs): return x ================================================ FILE: src/stablediffusion/ldm/models/diffusion/__init__.py ================================================ ================================================ FILE: src/stablediffusion/ldm/models/diffusion/classifier.py ================================================ import os import torch import pytorch_lightning as pl from omegaconf import OmegaConf from 
torch.nn import functional as F from torch.optim import AdamW from torch.optim.lr_scheduler import LambdaLR from copy import deepcopy from einops import rearrange from glob import glob from natsort import natsorted from src.stablediffusion.ldm.modules.diffusionmodules.openaimodel import ( EncoderUNetModel, UNetModel, ) from src.stablediffusion.ldm.util import log_txt_as_img, default, ismap, instantiate_from_config __models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel} def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" return self class NoisyLatentImageClassifier(pl.LightningModule): def __init__( self, diffusion_path, num_classes, ckpt_path=None, pool='attention', label_key=None, diffusion_ckpt_path=None, scheduler_config=None, weight_decay=1.0e-2, log_steps=10, monitor='val/loss', *args, **kwargs, ): super().__init__(*args, **kwargs) self.num_classes = num_classes # get latest config of diffusion model diffusion_config = natsorted( glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')) )[-1] self.diffusion_config = OmegaConf.load(diffusion_config).model self.diffusion_config.params.ckpt_path = diffusion_ckpt_path self.load_diffusion() self.monitor = monitor self.numd = ( self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 ) self.log_time_interval = ( self.diffusion_model.num_timesteps // log_steps ) self.log_steps = log_steps self.label_key = ( label_key if not hasattr(self.diffusion_model, 'cond_stage_key') else self.diffusion_model.cond_stage_key ) assert ( self.label_key is not None ), 'label_key neither in diffusion model nor in model.params' if self.label_key not in __models__: raise NotImplementedError() self.load_classifier(ckpt_path, pool) self.scheduler_config = scheduler_config self.use_scheduler = self.scheduler_config is not None self.weight_decay = weight_decay def init_from_ckpt(self, path, ignore_keys=list(), 
only_model=False): sd = torch.load(path, map_location='cpu') if 'state_dict' in list(sd.keys()): sd = sd['state_dict'] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): print('Deleting key {} from state_dict.'.format(k)) del sd[k] missing, unexpected = ( self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) ) print( f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys' ) if len(missing) > 0: print(f'Missing Keys: {missing}') if len(unexpected) > 0: print(f'Unexpected Keys: {unexpected}') def load_diffusion(self): model = instantiate_from_config(self.diffusion_config) self.diffusion_model = model.eval() self.diffusion_model.train = disabled_train for param in self.diffusion_model.parameters(): param.requires_grad = False def load_classifier(self, ckpt_path, pool): model_config = deepcopy( self.diffusion_config.params.unet_config.params ) model_config.in_channels = ( self.diffusion_config.params.unet_config.params.out_channels ) model_config.out_channels = self.num_classes if self.label_key == 'class_label': model_config.pool = pool self.model = __models__[self.label_key](**model_config) if ckpt_path is not None: print( '#####################################################################' ) print(f'load from ckpt "{ckpt_path}"') print( '#####################################################################' ) self.init_from_ckpt(ckpt_path) @torch.no_grad() def get_x_noisy(self, x, t, noise=None): noise = default(noise, lambda: torch.randn_like(x)) continuous_sqrt_alpha_cumprod = None if self.diffusion_model.use_continuous_noise: continuous_sqrt_alpha_cumprod = ( self.diffusion_model.sample_continuous_noise_level( x.shape[0], t + 1 ) ) # todo: make sure t+1 is correct here return self.diffusion_model.q_sample( x_start=x, t=t, noise=noise, continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod, ) def forward(self, x_noisy, t, *args, **kwargs): 
return self.model(x_noisy, t) @torch.no_grad() def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] x = rearrange(x, 'b h w c -> b c h w') x = x.to(memory_format=torch.contiguous_format).float() return x @torch.no_grad() def get_conditioning(self, batch, k=None): if k is None: k = self.label_key assert k is not None, 'Needs to provide label key' targets = batch[k].to(self.device) if self.label_key == 'segmentation': targets = rearrange(targets, 'b h w c -> b c h w') for down in range(self.numd): h, w = targets.shape[-2:] targets = F.interpolate( targets, size=(h // 2, w // 2), mode='nearest' ) # targets = rearrange(targets,'b c h w -> b h w c') return targets def compute_top_k(self, logits, labels, k, reduction='mean'): _, top_ks = torch.topk(logits, k, dim=1) if reduction == 'mean': return ( (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() ) elif reduction == 'none': return (top_ks == labels[:, None]).float().sum(dim=-1) def on_train_epoch_start(self): # save some memory self.diffusion_model.model.to('cpu') @torch.no_grad() def write_logs(self, loss, logits, targets): log_prefix = 'train' if self.training else 'val' log = {} log[f'{log_prefix}/loss'] = loss.mean() log[f'{log_prefix}/acc@1'] = self.compute_top_k( logits, targets, k=1, reduction='mean' ) log[f'{log_prefix}/acc@5'] = self.compute_top_k( logits, targets, k=5, reduction='mean' ) self.log_dict( log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True, ) self.log( 'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False ) self.log( 'global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True, ) lr = self.optimizers().param_groups[0]['lr'] self.log( 'lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True, ) def shared_step(self, batch, t=None): x, *_ = self.diffusion_model.get_input( batch, k=self.diffusion_model.first_stage_key ) targets = self.get_conditioning(batch) if targets.dim() == 4: targets = 
targets.argmax(dim=1) if t is None: t = torch.randint( 0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device, ).long() else: t = torch.full( size=(x.shape[0],), fill_value=t, device=self.device ).long() x_noisy = self.get_x_noisy(x, t) logits = self(x_noisy, t) loss = F.cross_entropy(logits, targets, reduction='none') self.write_logs(loss.detach(), logits.detach(), targets.detach()) loss = loss.mean() return loss, logits, x_noisy, targets def training_step(self, batch, batch_idx): loss, *_ = self.shared_step(batch) return loss def reset_noise_accs(self): self.noisy_acc = { t: {'acc@1': [], 'acc@5': []} for t in range( 0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t, ) } def on_validation_start(self): self.reset_noise_accs() @torch.no_grad() def validation_step(self, batch, batch_idx): loss, *_ = self.shared_step(batch) for t in self.noisy_acc: _, logits, _, targets = self.shared_step(batch, t) self.noisy_acc[t]['acc@1'].append( self.compute_top_k(logits, targets, k=1, reduction='mean') ) self.noisy_acc[t]['acc@5'].append( self.compute_top_k(logits, targets, k=5, reduction='mean') ) return loss def configure_optimizers(self): optimizer = AdamW( self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, ) if self.use_scheduler: scheduler = instantiate_from_config(self.scheduler_config) print('Setting up LambdaLR scheduler...') scheduler = [ { 'scheduler': LambdaLR( optimizer, lr_lambda=scheduler.schedule ), 'interval': 'step', 'frequency': 1, } ] return [optimizer], scheduler return optimizer @torch.no_grad() def log_images(self, batch, N=8, *args, **kwargs): log = dict() x = self.get_input(batch, self.diffusion_model.first_stage_key) log['inputs'] = x y = self.get_conditioning(batch) if self.label_key == 'class_label': y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label']) log['labels'] = y if ismap(y): log['labels'] = self.diffusion_model.to_rgb(y) for step in range(self.log_steps): 
current_time = step * self.log_time_interval _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) log[f'inputs@t{current_time}'] = x_noisy pred = F.one_hot( logits.argmax(dim=1), num_classes=self.num_classes ) pred = rearrange(pred, 'b h w c -> b c h w') log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb( pred ) for key in log: log[key] = log[key][:N] return log ================================================ FILE: src/stablediffusion/ldm/models/diffusion/ddim.py ================================================ """SAMPLING ONLY.""" import torch import numpy as np from tqdm import tqdm from functools import partial from src.stablediffusion.ldm.dream.devices import choose_torch_device from src.stablediffusion.ldm.modules.diffusionmodules.util import ( make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor, ) class DDIMSampler(object): def __init__(self, model, schedule='linear', device=None, **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule self.device = device or choose_torch_device() def register_buffer(self, name, attr): if type(attr) == torch.Tensor: if attr.device != torch.device(self.device): attr = attr.to(dtype=torch.float32, device=self.device) setattr(self, name, attr) def make_schedule( self, ddim_num_steps, ddim_discretize='uniform', ddim_eta=0.0, verbose=True, ): self.ddim_timesteps = make_ddim_timesteps( ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose, ) alphas_cumprod = self.model.alphas_cumprod assert ( alphas_cumprod.shape[0] == self.ddpm_num_timesteps ), 'alphas have to be defined for each timestep' to_torch = ( lambda x: x.clone() .detach() .to(torch.float32) .to(self.model.device) ) self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer( 'alphas_cumprod_prev', 
to_torch(self.model.alphas_cumprod_prev) ) # calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer( 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())) ) self.register_buffer( 'sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), ) self.register_buffer( 'log_one_minus_alphas_cumprod', to_torch(np.log(1.0 - alphas_cumprod.cpu())), ) self.register_buffer( 'sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())), ) self.register_buffer( 'sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), ) # ddim sampling parameters ( ddim_sigmas, ddim_alphas, ddim_alphas_prev, ) = make_ddim_sampling_parameters( alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose, ) self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer( 'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas) ) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) ) self.register_buffer( 'ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps, ) @torch.no_grad() def sample( self, S, batch_size, shape, conditioning=None, callback=None, normals_sequence=None, img_callback=None, quantize_x0=False, eta=0.0, mask=None, x0=None, temperature=1.0, noise_dropout=0.0, score_corrector=None, corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1.0, unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
**kwargs, ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] if cbs != batch_size: print( f'Warning: Got {cbs} conditionings but batch-size is {batch_size}' ) else: if conditioning.shape[0] != batch_size: print( f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}' ) self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) # sampling C, H, W = shape size = (batch_size, C, H, W) print(f'Data shape for DDIM sampling is {size}, eta {eta}') samples, intermediates = self.ddim_sampling( conditioning, size, callback=callback, img_callback=img_callback, quantize_denoised=quantize_x0, mask=mask, x0=x0, ddim_use_original_steps=False, noise_dropout=noise_dropout, temperature=temperature, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, x_T=x_T, log_every_t=log_every_t, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, ) return samples, intermediates # This routine gets called from img2img @torch.no_grad() def ddim_sampling( self, cond, shape, x_T=None, ddim_use_original_steps=False, callback=None, timesteps=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1.0, noise_dropout=0.0, score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, ): device = self.model.betas.device b = shape[0] if x_T is None: img = torch.randn(shape, device=device) else: img = x_T if timesteps is None: timesteps = ( self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps ) elif timesteps is not None and not ddim_use_original_steps: subset_end = ( int( min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0] ) - 1 ) timesteps = self.ddim_timesteps[:subset_end] intermediates = {'x_inter': [img], 'pred_x0': [img]} time_range = ( reversed(range(0, timesteps)) if 
ddim_use_original_steps else np.flip(timesteps) ) total_steps = ( timesteps if ddim_use_original_steps else timesteps.shape[0] ) print(f'Running DDIM Sampling with {total_steps} timesteps') iterator = tqdm( time_range, desc='DDIM Sampler', total=total_steps, dynamic_ncols=True, ) for i, step in enumerate(iterator): index = total_steps - i - 1 ts = torch.full((b,), step, device=device, dtype=torch.long) if mask is not None: assert x0 is not None img_orig = self.model.q_sample( x0, ts ) # TODO: deterministic forward pass? img = img_orig * mask + (1.0 - mask) * img outs = self.p_sample_ddim( img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, quantize_denoised=quantize_denoised, temperature=temperature, noise_dropout=noise_dropout, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, ) img, pred_x0 = outs if callback: callback(i) if img_callback: img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(img) intermediates['pred_x0'].append(pred_x0) return img, intermediates # This routine gets called from ddim_sampling() and decode() @torch.no_grad() def p_sample_ddim( self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1.0, noise_dropout=0.0, score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, ): b, *_, device = *x.shape, x.device if ( unconditional_conditioning is None or unconditional_guidance_scale == 1.0 ): e_t = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) t_in = torch.cat([t] * 2) c_in = torch.cat([unconditional_conditioning, c]) e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) e_t = e_t_uncond + unconditional_guidance_scale * ( e_t - e_t_uncond ) if score_corrector is not None: assert 
self.model.parameterization == 'eps' e_t = score_corrector.modify_score( self.model, e_t, x, t, c, **corrector_kwargs ) alphas = ( self.model.alphas_cumprod if use_original_steps else self.ddim_alphas ) alphas_prev = ( self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev ) sqrt_one_minus_alphas = ( self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas ) sigmas = ( self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas ) # select parameters corresponding to the currently considered timestep a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) sqrt_one_minus_at = torch.full( (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device ) # current prediction for x_0 pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) # direction pointing to x_t dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t noise = ( sigma_t * noise_like(x.shape, device, repeat_noise) * temperature ) if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 @torch.no_grad() def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): # fast, but does not allow for exact reconstruction # t serves as an index to gather the correct alphas if use_original_steps: sqrt_alphas_cumprod = self.sqrt_alphas_cumprod sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod else: sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas if noise is None: noise = torch.randn_like(x0) return ( extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, 
x0.shape) * noise ) @torch.no_grad() def decode( self, x_latent, cond, t_start, img_callback=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, use_original_steps=False, init_latent = None, mask = None, ): timesteps = ( np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps ) timesteps = timesteps[:t_start] time_range = np.flip(timesteps) total_steps = timesteps.shape[0] print(f'Running DDIM Sampling with {total_steps} timesteps') iterator = tqdm(time_range, desc='Decoding image', total=total_steps) x_dec = x_latent x0 = init_latent for i, step in enumerate(iterator): index = total_steps - i - 1 ts = torch.full( (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long, ) if mask is not None: assert x0 is not None xdec_orig = self.model.q_sample( x0, ts ) # TODO: deterministic forward pass? x_dec = xdec_orig * mask + (1.0 - mask) * x_dec x_dec, _ = self.p_sample_ddim( x_dec, cond, ts, index=index, use_original_steps=use_original_steps, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, ) if img_callback: img_callback(x_dec, i) return x_dec ================================================ FILE: src/stablediffusion/ldm/models/diffusion/ddpm.py ================================================ """ wild mixture of https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py https://github.com/CompVis/taming-transformers -- merci """ import torch import torch.nn as nn import os import numpy as np import pytorch_lightning as pl from torch.optim.lr_scheduler import LambdaLR from einops import rearrange, repeat from contextlib import contextmanager from functools import partial from tqdm import tqdm from torchvision.utils 
import make_grid from pytorch_lightning.utilities.distributed import rank_zero_only import urllib from src.stablediffusion.ldm.util import ( log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config, ) from src.stablediffusion.ldm.modules.ema import LitEma from src.stablediffusion.ldm.modules.distributions.distributions import ( normal_kl, DiagonalGaussianDistribution, ) from src.stablediffusion.ldm.models.autoencoder import ( VQModelInterface, IdentityFirstStage, AutoencoderKL, ) from src.stablediffusion.ldm.modules.diffusionmodules.util import ( make_beta_schedule, extract_into_tensor, noise_like, ) from src.stablediffusion.ldm.models.diffusion.ddim import DDIMSampler __conditioning_keys__ = { 'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y', } def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" return self def uniform_on_device(r1, r2, shape, device): return (r1 - r2) * torch.rand(*shape, device=device) + r2 class DDPM(pl.LightningModule): # classic DDPM with Gaussian diffusion, in image space def __init__( self, unet_config, timesteps=1000, beta_schedule='linear', loss_type='l2', ckpt_path=None, ignore_keys=[], load_only_unet=False, monitor='val/loss', use_ema=True, first_stage_key='image', image_size=256, channels=3, log_every_t=100, clip_denoised=True, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, given_betas=None, original_elbo_weight=0.0, embedding_reg_weight=0.0, v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta l_simple_weight=1.0, conditioning_key=None, parameterization='eps', # all assuming fixed variance schedules scheduler_config=None, use_positional_encodings=False, learn_logvar=False, logvar_init=0.0, ): super().__init__() assert parameterization in [ 'eps', 'x0', ], 'currently only supporting "eps" and "x0"' self.parameterization = parameterization print( 
f'{self.__class__.__name__}: Running in {self.parameterization}-prediction mode' ) self.cond_stage_model = None self.clip_denoised = clip_denoised self.log_every_t = log_every_t self.first_stage_key = first_stage_key self.image_size = image_size # try conv? self.channels = channels self.use_positional_encodings = use_positional_encodings self.model = DiffusionWrapper(unet_config, conditioning_key) count_params(self.model, verbose=True) self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self.model) print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.') self.use_scheduler = scheduler_config is not None if self.use_scheduler: self.scheduler_config = scheduler_config self.v_posterior = v_posterior self.original_elbo_weight = original_elbo_weight self.l_simple_weight = l_simple_weight self.embedding_reg_weight = embedding_reg_weight if monitor is not None: self.monitor = monitor if ckpt_path is not None: self.init_from_ckpt( ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet ) self.register_schedule( given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, ) self.loss_type = loss_type self.learn_logvar = learn_logvar self.logvar = torch.full( fill_value=logvar_init, size=(self.num_timesteps,) ) if self.learn_logvar: self.logvar = nn.Parameter(self.logvar, requires_grad=True) def register_schedule( self, given_betas=None, beta_schedule='linear', timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, ): if exists(given_betas): betas = given_betas else: betas = make_beta_schedule( beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, ) alphas = 1.0 - betas alphas_cumprod = np.cumprod(alphas, axis=0) alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) (timesteps,) = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end assert ( 
alphas_cumprod.shape[0] == self.num_timesteps ), 'alphas have to be defined for each timestep' to_torch = partial(torch.tensor, dtype=torch.float32) self.register_buffer('betas', to_torch(betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer( 'alphas_cumprod_prev', to_torch(alphas_cumprod_prev) ) # calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer( 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)) ) self.register_buffer( 'sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1.0 - alphas_cumprod)), ) self.register_buffer( 'log_one_minus_alphas_cumprod', to_torch(np.log(1.0 - alphas_cumprod)), ) self.register_buffer( 'sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1.0 / alphas_cumprod)), ) self.register_buffer( 'sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1.0 / alphas_cumprod - 1)), ) # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = (1 - self.v_posterior) * betas * ( 1.0 - alphas_cumprod_prev ) / (1.0 - alphas_cumprod) + self.v_posterior * betas # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) self.register_buffer( 'posterior_variance', to_torch(posterior_variance) ) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain self.register_buffer( 'posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))), ) self.register_buffer( 'posterior_mean_coef1', to_torch( betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) ), ) self.register_buffer( 'posterior_mean_coef2', to_torch( (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) ), ) if self.parameterization == 'eps': lvlb_weights = self.betas**2 / ( 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod) ) elif self.parameterization == 'x0': lvlb_weights = ( 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod)) ) else: raise NotImplementedError('mu not supported') # TODO how to choose this term lvlb_weights[0] = lvlb_weights[1] self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) assert not torch.isnan(self.lvlb_weights).all() @contextmanager def ema_scope(self, context=None): if self.use_ema: self.model_ema.store(self.model.parameters()) self.model_ema.copy_to(self.model) if context is not None: print(f'{context}: Switched to EMA weights') try: yield None finally: if self.use_ema: self.model_ema.restore(self.model.parameters()) if context is not None: print(f'{context}: Restored training weights') def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): sd = torch.load(path, map_location='cpu') if 'state_dict' in list(sd.keys()): sd = sd['state_dict'] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): print('Deleting key {} from state_dict.'.format(k)) del sd[k] missing, unexpected = ( self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) ) print( f'Restored from {path} with {len(missing)} missing 
and {len(unexpected)} unexpected keys' ) if len(missing) > 0: print(f'Missing Keys: {missing}') if len(unexpected) > 0: print(f'Unexpected Keys: {unexpected}') def q_mean_variance(self, x_start, t): """ Get the distribution q(x_t | x_0). :param x_start: the [N x C x ...] tensor of noiseless inputs. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ mean = ( extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start ) variance = extract_into_tensor( 1.0 - self.alphas_cumprod, t, x_start.shape ) log_variance = extract_into_tensor( self.log_one_minus_alphas_cumprod, t, x_start.shape ) return mean, variance, log_variance def predict_start_from_noise(self, x_t, t, noise): return ( extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract_into_tensor( self.sqrt_recipm1_alphas_cumprod, t, x_t.shape ) * noise ) def q_posterior(self, x_start, x_t, t): posterior_mean = ( extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = extract_into_tensor( self.posterior_variance, t, x_t.shape ) posterior_log_variance_clipped = extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) return ( posterior_mean, posterior_variance, posterior_log_variance_clipped, ) def p_mean_variance(self, x, t, clip_denoised: bool): model_out = self.model(x, t) if self.parameterization == 'eps': x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) elif self.parameterization == 'x0': x_recon = model_out if clip_denoised: x_recon.clamp_(-1.0, 1.0) ( model_mean, posterior_variance, posterior_log_variance, ) = self.q_posterior(x_start=x_recon, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): b, *_, device = *x.shape, x.device 
model_mean, _, model_log_variance = self.p_mean_variance( x=x, t=t, clip_denoised=clip_denoised ) noise = noise_like(x.shape, device, repeat_noise) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape( b, *((1,) * (len(x.shape) - 1)) ) return ( model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise ) @torch.no_grad() def p_sample_loop(self, shape, return_intermediates=False): device = self.betas.device b = shape[0] img = torch.randn(shape, device=device) intermediates = [img] for i in tqdm( reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps, dynamic_ncols=True, ): img = self.p_sample( img, torch.full((b,), i, device=device, dtype=torch.long), clip_denoised=self.clip_denoised, ) if i % self.log_every_t == 0 or i == self.num_timesteps - 1: intermediates.append(img) if return_intermediates: return img, intermediates return img @torch.no_grad() def sample(self, batch_size=16, return_intermediates=False): image_size = self.image_size channels = self.channels return self.p_sample_loop( (batch_size, channels, image_size, image_size), return_intermediates=return_intermediates, ) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) return ( extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract_into_tensor( self.sqrt_one_minus_alphas_cumprod, t, x_start.shape ) * noise ) def get_loss(self, pred, target, mean=True): if self.loss_type == 'l1': loss = (target - pred).abs() if mean: loss = loss.mean() elif self.loss_type == 'l2': if mean: loss = torch.nn.functional.mse_loss(target, pred) else: loss = torch.nn.functional.mse_loss( target, pred, reduction='none' ) else: raise NotImplementedError("unknown loss type '{loss_type}'") return loss def p_losses(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) model_out = self.model(x_noisy, t) 
loss_dict = {} if self.parameterization == 'eps': target = noise elif self.parameterization == 'x0': target = x_start else: raise NotImplementedError( f'Paramterization {self.parameterization} not yet supported' ) loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) log_prefix = 'train' if self.training else 'val' loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) loss_simple = loss.mean() * self.l_simple_weight loss_vlb = (self.lvlb_weights[t] * loss).mean() loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) loss = loss_simple + self.original_elbo_weight * loss_vlb loss_dict.update({f'{log_prefix}/loss': loss}) return loss, loss_dict def forward(self, x, *args, **kwargs): # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' t = torch.randint( 0, self.num_timesteps, (x.shape[0],), device=self.device ).long() return self.p_losses(x, t, *args, **kwargs) def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] x = rearrange(x, 'b h w c -> b c h w') x = x.to(memory_format=torch.contiguous_format).float() return x def shared_step(self, batch): x = self.get_input(batch, self.first_stage_key) loss, loss_dict = self(x) return loss, loss_dict def training_step(self, batch, batch_idx): loss, loss_dict = self.shared_step(batch) self.log_dict( loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True ) self.log( 'global_step', self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False, ) if self.use_scheduler: lr = self.optimizers().param_groups[0]['lr'] self.log( 'lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False, ) return loss @torch.no_grad() def validation_step(self, batch, batch_idx): _, loss_dict_no_ema = self.shared_step(batch) with self.ema_scope(): _, loss_dict_ema = self.shared_step(batch) loss_dict_ema = { key + '_ema': loss_dict_ema[key] for key in 
loss_dict_ema } self.log_dict( loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True, ) self.log_dict( loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True, ) def on_train_batch_end(self, *args, **kwargs): if self.use_ema: self.model_ema(self.model) def _get_rows_from_list(self, samples): n_imgs_per_row = len(samples) denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid @torch.no_grad() def log_images( self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs ): log = dict() x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) x = x.to(self.device)[:N] log['inputs'] = x # get diffusion row diffusion_row = list() x_start = x[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: t = repeat(torch.tensor([t]), '1 -> b', b=n_row) t = t.to(self.device).long() noise = torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) diffusion_row.append(x_noisy) log['diffusion_row'] = self._get_rows_from_list(diffusion_row) if sample: # get denoise row with self.ema_scope('Plotting'): samples, denoise_row = self.sample( batch_size=N, return_intermediates=True ) log['samples'] = samples log['denoise_row'] = self._get_rows_from_list(denoise_row) if return_keys: if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: return log else: return {key: log[key] for key in return_keys} return log def configure_optimizers(self): lr = self.learning_rate params = list(self.model.parameters()) if self.learn_logvar: params = params + [self.logvar] opt = torch.optim.AdamW(params, lr=lr) return opt class LatentDiffusion(DDPM): """main class""" def __init__( self, first_stage_config, cond_stage_config, personalization_config, num_timesteps_cond=None, 
cond_stage_key='image', cond_stage_trainable=False, concat_mode=True, cond_stage_forward=None, conditioning_key=None, scale_factor=1.0, scale_by_std=False, *args, **kwargs, ): self.num_timesteps_cond = default(num_timesteps_cond, 1) self.scale_by_std = scale_by_std assert self.num_timesteps_cond <= kwargs['timesteps'] # for backwards compatibility after implementation of DiffusionWrapper if conditioning_key is None: conditioning_key = 'concat' if concat_mode else 'crossattn' if cond_stage_config == '__is_unconditional__': conditioning_key = None ckpt_path = kwargs.pop('ckpt_path', None) ignore_keys = kwargs.pop('ignore_keys', []) super().__init__(conditioning_key=conditioning_key, *args, **kwargs) self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key try: self.num_downs = ( len(first_stage_config.params.ddconfig.ch_mult) - 1 ) except: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor else: self.register_buffer('scale_factor', torch.tensor(scale_factor)) self.instantiate_first_stage(first_stage_config) self.instantiate_cond_stage(cond_stage_config) self.cond_stage_forward = cond_stage_forward self.clip_denoised = False self.bbox_tokenizer = None self.restarted_from_ckpt = False if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys) self.restarted_from_ckpt = True self.cond_stage_model.train = disabled_train for param in self.cond_stage_model.parameters(): param.requires_grad = False self.model.eval() self.model.train = disabled_train for param in self.model.parameters(): param.requires_grad = False self.embedding_manager = self.instantiate_embedding_manager( personalization_config, self.cond_stage_model ) self.emb_ckpt_counter = 0 # if self.embedding_manager.is_clip: # self.cond_stage_model.update_embedding_func(self.embedding_manager) for param in self.embedding_manager.embedding_parameters(): param.requires_grad = True def make_cond_schedule( self, ): self.cond_ids 
= torch.full( size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long, ) ids = torch.round( torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) ).long() self.cond_ids[: self.num_timesteps_cond] = ids @rank_zero_only @torch.no_grad() def on_train_batch_start(self, batch, batch_idx, dataloader_idx): # only for very first batch if ( self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt ): assert ( self.scale_factor == 1.0 ), 'rather not use custom rescaling and std-rescaling simultaneously' # set rescale weight to 1./std of encodings print('### USING STD-RESCALING ###') x = super().get_input(batch, self.first_stage_key) x = x.to(self.device) encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() del self.scale_factor self.register_buffer('scale_factor', 1.0 / z.flatten().std()) print(f'setting self.scale_factor to {self.scale_factor}') print('### USING STD-RESCALING ###') def register_schedule( self, given_betas=None, beta_schedule='linear', timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, ): super().register_schedule( given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s, ) self.shorten_cond_schedule = self.num_timesteps_cond > 1 if self.shorten_cond_schedule: self.make_cond_schedule() def instantiate_first_stage(self, config): model = instantiate_from_config(config) self.first_stage_model = model.eval() self.first_stage_model.train = disabled_train for param in self.first_stage_model.parameters(): param.requires_grad = False def instantiate_cond_stage(self, config): if not self.cond_stage_trainable: if config == '__is_first_stage__': print('Using first stage also as cond stage.') self.cond_stage_model = self.first_stage_model elif config == '__is_unconditional__': print( f'Training {self.__class__.__name__} as an unconditional model.' 
) self.cond_stage_model = None # self.be_unconditional = True else: model = instantiate_from_config(config) self.cond_stage_model = model.eval() self.cond_stage_model.train = disabled_train for param in self.cond_stage_model.parameters(): param.requires_grad = False else: assert config != '__is_first_stage__' assert config != '__is_unconditional__' try: model = instantiate_from_config(config) except urllib.error.URLError: raise SystemExit( "* Couldn't load a dependency. Try running scripts/preload_models.py from an internet-conected machine." ) self.cond_stage_model = model def instantiate_embedding_manager(self, config, embedder): model = instantiate_from_config(config, embedder=embedder) if config.params.get( 'embedding_manager_ckpt', None ): # do not load if missing OR empty string model.load(config.params.embedding_manager_ckpt) return model def _get_denoise_row_from_list( self, samples, desc='', force_no_decoder_quantization=False ): denoise_row = [] for zd in tqdm(samples, desc=desc): denoise_row.append( self.decode_first_stage( zd.to(self.device), force_not_quantize=force_no_decoder_quantization, ) ) n_imgs_per_row = len(denoise_row) denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid def get_first_stage_encoding(self, encoder_posterior): if isinstance(encoder_posterior, DiagonalGaussianDistribution): z = encoder_posterior.sample() elif isinstance(encoder_posterior, torch.Tensor): z = encoder_posterior else: raise NotImplementedError( f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" ) return self.scale_factor * z def get_learned_conditioning(self, c): if self.cond_stage_forward is None: if hasattr(self.cond_stage_model, 'encode') and callable( self.cond_stage_model.encode ): c = self.cond_stage_model.encode( 
c, embedding_manager=self.embedding_manager ) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() else: c = self.cond_stage_model(c) else: assert hasattr(self.cond_stage_model, self.cond_stage_forward) c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) return c def meshgrid(self, h, w): y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) arr = torch.cat([y, x], dim=-1) return arr def delta_border(self, h, w): """ :param h: height :param w: width :return: normalized distance to image border, wtith min distance = 0 at border and max dist = 0.5 at image center """ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) arr = self.meshgrid(h, w) / lower_right_corner dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] edge_dist = torch.min( torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1 )[0] return edge_dist def get_weighting(self, h, w, Ly, Lx, device): weighting = self.delta_border(h, w) weighting = torch.clip( weighting, self.split_input_params['clip_min_weight'], self.split_input_params['clip_max_weight'], ) weighting = ( weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) ) if self.split_input_params['tie_braker']: L_weighting = self.delta_border(Ly, Lx) L_weighting = torch.clip( L_weighting, self.split_input_params['clip_min_tie_weight'], self.split_input_params['clip_max_tie_weight'], ) L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) weighting = weighting * L_weighting return weighting def get_fold_unfold( self, x, kernel_size, stride, uf=1, df=1 ): # todo load once not every time, shorten code """ :param x: img of size (bs, c, h, w) :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) """ bs, nc, h, w = x.shape # number of crops in image Ly = (h - kernel_size[0]) // stride[0] + 1 Lx = (w - kernel_size[1]) // stride[1] + 1 if uf == 1 and df == 1: fold_params = dict( 
kernel_size=kernel_size, dilation=1, padding=0, stride=stride ) unfold = torch.nn.Unfold(**fold_params) fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) weighting = self.get_weighting( kernel_size[0], kernel_size[1], Ly, Lx, x.device ).to(x.dtype) normalization = fold(weighting).view( 1, 1, h, w ) # normalizes the overlap weighting = weighting.view( (1, 1, kernel_size[0], kernel_size[1], Ly * Lx) ) elif uf > 1 and df == 1: fold_params = dict( kernel_size=kernel_size, dilation=1, padding=0, stride=stride ) unfold = torch.nn.Unfold(**fold_params) fold_params2 = dict( kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), dilation=1, padding=0, stride=(stride[0] * uf, stride[1] * uf), ) fold = torch.nn.Fold( output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2 ) weighting = self.get_weighting( kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device ).to(x.dtype) normalization = fold(weighting).view( 1, 1, h * uf, w * uf ) # normalizes the overlap weighting = weighting.view( (1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx) ) elif df > 1 and uf == 1: fold_params = dict( kernel_size=kernel_size, dilation=1, padding=0, stride=stride ) unfold = torch.nn.Unfold(**fold_params) fold_params2 = dict( kernel_size=(kernel_size[0] // df, kernel_size[0] // df), dilation=1, padding=0, stride=(stride[0] // df, stride[1] // df), ) fold = torch.nn.Fold( output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2, ) weighting = self.get_weighting( kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device ).to(x.dtype) normalization = fold(weighting).view( 1, 1, h // df, w // df ) # normalizes the overlap weighting = weighting.view( (1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx) ) else: raise NotImplementedError return fold, unfold, normalization, weighting @torch.no_grad() def get_input( self, batch, k, return_first_stage_outputs=False, force_c_encode=False, cond_key=None, return_original_cond=False, bs=None, ): x = 
super().get_input(batch, k) if bs is not None: x = x[:bs] x = x.to(self.device) encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() if self.model.conditioning_key is not None: if cond_key is None: cond_key = self.cond_stage_key if cond_key != self.first_stage_key: if cond_key in ['caption', 'coordinates_bbox']: xc = batch[cond_key] elif cond_key == 'class_label': xc = batch else: xc = super().get_input(batch, cond_key).to(self.device) else: xc = x if not self.cond_stage_trainable or force_c_encode: if isinstance(xc, dict) or isinstance(xc, list): # import pudb; pudb.set_trace() c = self.get_learned_conditioning(xc) else: c = self.get_learned_conditioning(xc.to(self.device)) else: c = xc if bs is not None: c = c[:bs] if self.use_positional_encodings: pos_x, pos_y = self.compute_latent_shifts(batch) ckey = __conditioning_keys__[self.model.conditioning_key] c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} else: c = None xc = None if self.use_positional_encodings: pos_x, pos_y = self.compute_latent_shifts(batch) c = {'pos_x': pos_x, 'pos_y': pos_y} out = [z, c] if return_first_stage_outputs: xrec = self.decode_first_stage(z) out.extend([x, xrec]) if return_original_cond: out.append(xc) return out @torch.no_grad() def decode_first_stage( self, z, predict_cids=False, force_not_quantize=False ): if predict_cids: if z.dim() == 4: z = torch.argmax(z.exp(), dim=1).long() z = self.first_stage_model.quantize.get_codebook_entry( z, shape=None ) z = rearrange(z, 'b h w c -> b c h w').contiguous() z = 1.0 / self.scale_factor * z if hasattr(self, 'split_input_params'): if self.split_input_params['patch_distributed_vq']: ks = self.split_input_params['ks'] # eg. (128, 128) stride = self.split_input_params['stride'] # eg. 
(64, 64) uf = self.split_input_params['vqf'] bs, nc, h, w = z.shape if ks[0] > h or ks[1] > w: ks = (min(ks[0], h), min(ks[1], w)) print('reducing Kernel') if stride[0] > h or stride[1] > w: stride = (min(stride[0], h), min(stride[1], w)) print('reducing stride') fold, unfold, normalization, weighting = self.get_fold_unfold( z, ks, stride, uf=uf ) z = unfold(z) # (bn, nc * prod(**ks), L) # 1. Reshape to img shape z = z.view( (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) ) # (bn, nc, ks[0], ks[1], L ) # 2. apply model loop over last dim if isinstance(self.first_stage_model, VQModelInterface): output_list = [ self.first_stage_model.decode( z[:, :, :, :, i], force_not_quantize=predict_cids or force_not_quantize, ) for i in range(z.shape[-1]) ] else: output_list = [ self.first_stage_model.decode(z[:, :, :, :, i]) for i in range(z.shape[-1]) ] o = torch.stack( output_list, axis=-1 ) # # (bn, nc, ks[0], ks[1], L) o = o * weighting # Reverse 1. reshape to img shape o = o.view( (o.shape[0], -1, o.shape[-1]) ) # (bn, nc * ks[0] * ks[1], L) # stitch crops together decoded = fold(o) decoded = decoded / normalization # norm is shape (1, 1, h, w) return decoded else: if isinstance(self.first_stage_model, VQModelInterface): return self.first_stage_model.decode( z, force_not_quantize=predict_cids or force_not_quantize, ) else: return self.first_stage_model.decode(z) else: if isinstance(self.first_stage_model, VQModelInterface): return self.first_stage_model.decode( z, force_not_quantize=predict_cids or force_not_quantize ) else: return self.first_stage_model.decode(z) # same as above but without decorator def differentiable_decode_first_stage( self, z, predict_cids=False, force_not_quantize=False ): if predict_cids: if z.dim() == 4: z = torch.argmax(z.exp(), dim=1).long() z = self.first_stage_model.quantize.get_codebook_entry( z, shape=None ) z = rearrange(z, 'b h w c -> b c h w').contiguous() z = 1.0 / self.scale_factor * z if hasattr(self, 'split_input_params'): if 
self.split_input_params['patch_distributed_vq']: ks = self.split_input_params['ks'] # eg. (128, 128) stride = self.split_input_params['stride'] # eg. (64, 64) uf = self.split_input_params['vqf'] bs, nc, h, w = z.shape if ks[0] > h or ks[1] > w: ks = (min(ks[0], h), min(ks[1], w)) print('reducing Kernel') if stride[0] > h or stride[1] > w: stride = (min(stride[0], h), min(stride[1], w)) print('reducing stride') fold, unfold, normalization, weighting = self.get_fold_unfold( z, ks, stride, uf=uf ) z = unfold(z) # (bn, nc * prod(**ks), L) # 1. Reshape to img shape z = z.view( (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) ) # (bn, nc, ks[0], ks[1], L ) # 2. apply model loop over last dim if isinstance(self.first_stage_model, VQModelInterface): output_list = [ self.first_stage_model.decode( z[:, :, :, :, i], force_not_quantize=predict_cids or force_not_quantize, ) for i in range(z.shape[-1]) ] else: output_list = [ self.first_stage_model.decode(z[:, :, :, :, i]) for i in range(z.shape[-1]) ] o = torch.stack( output_list, axis=-1 ) # # (bn, nc, ks[0], ks[1], L) o = o * weighting # Reverse 1. reshape to img shape o = o.view( (o.shape[0], -1, o.shape[-1]) ) # (bn, nc * ks[0] * ks[1], L) # stitch crops together decoded = fold(o) decoded = decoded / normalization # norm is shape (1, 1, h, w) return decoded else: if isinstance(self.first_stage_model, VQModelInterface): return self.first_stage_model.decode( z, force_not_quantize=predict_cids or force_not_quantize, ) else: return self.first_stage_model.decode(z) else: if isinstance(self.first_stage_model, VQModelInterface): return self.first_stage_model.decode( z, force_not_quantize=predict_cids or force_not_quantize ) else: return self.first_stage_model.decode(z) @torch.no_grad() def encode_first_stage(self, x): if hasattr(self, 'split_input_params'): if self.split_input_params['patch_distributed_vq']: ks = self.split_input_params['ks'] # eg. (128, 128) stride = self.split_input_params['stride'] # eg. 
(64, 64) df = self.split_input_params['vqf'] self.split_input_params['original_image_size'] = x.shape[-2:] bs, nc, h, w = x.shape if ks[0] > h or ks[1] > w: ks = (min(ks[0], h), min(ks[1], w)) print('reducing Kernel') if stride[0] > h or stride[1] > w: stride = (min(stride[0], h), min(stride[1], w)) print('reducing stride') fold, unfold, normalization, weighting = self.get_fold_unfold( x, ks, stride, df=df ) z = unfold(x) # (bn, nc * prod(**ks), L) # Reshape to img shape z = z.view( (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) ) # (bn, nc, ks[0], ks[1], L ) output_list = [ self.first_stage_model.encode(z[:, :, :, :, i]) for i in range(z.shape[-1]) ] o = torch.stack(output_list, axis=-1) o = o * weighting # Reverse reshape to img shape o = o.view( (o.shape[0], -1, o.shape[-1]) ) # (bn, nc * ks[0] * ks[1], L) # stitch crops together decoded = fold(o) decoded = decoded / normalization return decoded else: return self.first_stage_model.encode(x) else: return self.first_stage_model.encode(x) def shared_step(self, batch, **kwargs): x, c = self.get_input(batch, self.first_stage_key) loss = self(x, c) return loss def forward(self, x, c, *args, **kwargs): t = torch.randint( 0, self.num_timesteps, (x.shape[0],), device=self.device ).long() if self.model.conditioning_key is not None: assert c is not None if self.cond_stage_trainable: c = self.get_learned_conditioning(c) if self.shorten_cond_schedule: # TODO: drop this option tc = self.cond_ids[t].to(self.device) c = self.q_sample( x_start=c, t=tc, noise=torch.randn_like(c.float()) ) return self.p_losses(x, c, t, *args, **kwargs) def _rescale_annotations( self, bboxes, crop_coordinates ): # TODO: move to dataset def rescale_bbox(bbox): x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) w = min(bbox[2] / crop_coordinates[2], 1 - x0) h = min(bbox[3] / crop_coordinates[3], 1 - y0) return x0, y0, w, h return [rescale_bbox(b) for b in bboxes] def 
apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): # hybrid case, cond is exptected to be a dict pass else: if not isinstance(cond, list): cond = [cond] key = ( 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' ) cond = {key: cond} if hasattr(self, 'split_input_params'): assert ( len(cond) == 1 ) # todo can only deal with one conditioning atm assert not return_ids ks = self.split_input_params['ks'] # eg. (128, 128) stride = self.split_input_params['stride'] # eg. (64, 64) h, w = x_noisy.shape[-2:] fold, unfold, normalization, weighting = self.get_fold_unfold( x_noisy, ks, stride ) z = unfold(x_noisy) # (bn, nc * prod(**ks), L) # Reshape to img shape z = z.view( (z.shape[0], -1, ks[0], ks[1], z.shape[-1]) ) # (bn, nc, ks[0], ks[1], L ) z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] if ( self.cond_stage_key in ['image', 'LR_image', 'segmentation', 'bbox_img'] and self.model.conditioning_key ): # todo check for completeness c_key = next(iter(cond.keys())) # get key c = next(iter(cond.values())) # get value assert ( len(c) == 1 ) # todo extend to list with more than one elem c = c[0] # get element c = unfold(c) c = c.view( (c.shape[0], -1, ks[0], ks[1], c.shape[-1]) ) # (bn, nc, ks[0], ks[1], L ) cond_list = [ {c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1]) ] elif self.cond_stage_key == 'coordinates_bbox': assert ( 'original_image_size' in self.split_input_params ), 'BoudingBoxRescaling is missing original_image_size' # assuming padding of unfold is always 0 and its dilation is always 1 n_patches_per_row = int((w - ks[0]) / stride[0] + 1) full_img_h, full_img_w = self.split_input_params[ 'original_image_size' ] # as we are operating on latents, we need the factor from the original image size to the # spatial latent size to properly rescale the crops for regenerating the bbox annotations num_downs = self.first_stage_model.encoder.num_resolutions - 1 rescale_latent = 2 ** (num_downs) # get top left 
postions of patches as conforming for the bbbox tokenizer, therefore we # need to rescale the tl patch coordinates to be in between (0,1) tl_patch_coordinates = [ ( rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h, ) for patch_nr in range(z.shape[-1]) ] # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) patch_limits = [ ( x_tl, y_tl, rescale_latent * ks[0] / full_img_w, rescale_latent * ks[1] / full_img_h, ) for x_tl, y_tl in tl_patch_coordinates ] # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] # tokenize crop coordinates for the bounding boxes of the respective patches patch_limits_tknzd = [ torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[ None ].to(self.device) for bbox in patch_limits ] # list of length l with tensors of shape (1, 2) print(patch_limits_tknzd[0].shape) # cut tknzd crop position from conditioning assert isinstance( cond, dict ), 'cond must be dict to be fed into model' cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) print(cut_cond.shape) adapted_cond = torch.stack( [ torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd ] ) adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') print(adapted_cond.shape) adapted_cond = self.get_learned_conditioning(adapted_cond) print(adapted_cond.shape) adapted_cond = rearrange( adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1] ) print(adapted_cond.shape) cond_list = [{'c_crossattn': [e]} for e in adapted_cond] else: cond_list = [ cond for i in range(z.shape[-1]) ] # Todo make this more efficient # apply model by loop over crops output_list = [ self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1]) ] assert not isinstance( output_list[0], tuple ) # todo cant deal with multiple model outputs check this never happens o = torch.stack(output_list, axis=-1) o = o * 
weighting # Reverse reshape to img shape o = o.view( (o.shape[0], -1, o.shape[-1]) ) # (bn, nc * ks[0] * ks[1], L) # stitch crops together x_recon = fold(o) / normalization else: x_recon = self.model(x_noisy, t, **cond) if isinstance(x_recon, tuple) and not return_ids: return x_recon[0] else: return x_recon def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart ) / extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _prior_bpd(self, x_start): """ Get the prior KL term for the variational lower-bound, measured in bits-per-dim. This term can't be optimized, as it only depends on the encoder. :param x_start: the [N x C x ...] tensor of inputs. :return: a batch of [N] KL values (in bits), one per batch element. """ batch_size = x_start.shape[0] t = torch.tensor( [self.num_timesteps - 1] * batch_size, device=x_start.device ) qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) kl_prior = normal_kl( mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 ) return mean_flat(kl_prior) / np.log(2.0) def p_losses(self, x_start, cond, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) model_output = self.apply_model(x_noisy, t, cond) loss_dict = {} prefix = 'train' if self.training else 'val' if self.parameterization == 'x0': target = x_start elif self.parameterization == 'eps': target = noise else: raise NotImplementedError() loss_simple = self.get_loss(model_output, target, mean=False).mean( [1, 2, 3] ) loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) logvar_t = self.logvar[t].to(self.device) loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar if self.learn_logvar: loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) loss_dict.update({'logvar': self.logvar.data.mean()}) loss = 
self.l_simple_weight * loss.mean() loss_vlb = self.get_loss(model_output, target, mean=False).mean( dim=(1, 2, 3) ) loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) loss += self.original_elbo_weight * loss_vlb loss_dict.update({f'{prefix}/loss': loss}) if self.embedding_reg_weight > 0: loss_embedding_reg = ( self.embedding_manager.embedding_to_coarse_loss().mean() ) loss_dict.update({f'{prefix}/loss_emb_reg': loss_embedding_reg}) loss += self.embedding_reg_weight * loss_embedding_reg loss_dict.update({f'{prefix}/loss': loss}) return loss, loss_dict def p_mean_variance( self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, return_x0=False, score_corrector=None, corrector_kwargs=None, ): t_in = t model_out = self.apply_model( x, t_in, c, return_ids=return_codebook_ids ) if score_corrector is not None: assert self.parameterization == 'eps' model_out = score_corrector.modify_score( self, model_out, x, t, c, **corrector_kwargs ) if return_codebook_ids: model_out, logits = model_out if self.parameterization == 'eps': x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) elif self.parameterization == 'x0': x_recon = model_out else: raise NotImplementedError() if clip_denoised: x_recon.clamp_(-1.0, 1.0) if quantize_denoised: x_recon, _, [_, _, indices] = self.first_stage_model.quantize( x_recon ) ( model_mean, posterior_variance, posterior_log_variance, ) = self.q_posterior(x_start=x_recon, x_t=x, t=t) if return_codebook_ids: return ( model_mean, posterior_variance, posterior_log_variance, logits, ) elif return_x0: return ( model_mean, posterior_variance, posterior_log_variance, x_recon, ) else: return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() def p_sample( self, x, c, t, clip_denoised=False, repeat_noise=False, return_codebook_ids=False, quantize_denoised=False, return_x0=False, temperature=1.0, noise_dropout=0.0, score_corrector=None, 
corrector_kwargs=None, ): b, *_, device = *x.shape, x.device outputs = self.p_mean_variance( x=x, c=c, t=t, clip_denoised=clip_denoised, return_codebook_ids=return_codebook_ids, quantize_denoised=quantize_denoised, return_x0=return_x0, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, ) if return_codebook_ids: raise DeprecationWarning('Support dropped.') model_mean, _, model_log_variance, logits = outputs elif return_x0: model_mean, _, model_log_variance, x0 = outputs else: model_mean, _, model_log_variance = outputs noise = noise_like(x.shape, device, repeat_noise) * temperature if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape( b, *((1,) * (len(x.shape) - 1)) ) if return_codebook_ids: return model_mean + nonzero_mask * ( 0.5 * model_log_variance ).exp() * noise, logits.argmax(dim=1) if return_x0: return ( model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0, ) else: return ( model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise ) @torch.no_grad() def progressive_denoising( self, cond, shape, verbose=True, callback=None, quantize_denoised=False, img_callback=None, mask=None, x0=None, temperature=1.0, noise_dropout=0.0, score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, log_every_t=None, ): if not log_every_t: log_every_t = self.log_every_t timesteps = self.num_timesteps if batch_size is not None: b = batch_size if batch_size is not None else shape[0] shape = [batch_size] + list(shape) else: b = batch_size = shape[0] if x_T is None: img = torch.randn(shape, device=self.device) else: img = x_T intermediates = [] if cond is not None: if isinstance(cond, dict): cond = { key: cond[key][:batch_size] if not isinstance(cond[key], list) else list(map(lambda x: x[:batch_size], cond[key])) for key in cond } else: cond = ( [c[:batch_size] for c in cond] if isinstance(cond, list) else 
cond[:batch_size] ) if start_T is not None: timesteps = min(timesteps, start_T) iterator = ( tqdm( reversed(range(0, timesteps)), desc='Progressive Generation', total=timesteps, ) if verbose else reversed(range(0, timesteps)) ) if type(temperature) == float: temperature = [temperature] * timesteps for i in iterator: ts = torch.full((b,), i, device=self.device, dtype=torch.long) if self.shorten_cond_schedule: assert self.model.conditioning_key != 'hybrid' tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample( x_start=cond, t=tc, noise=torch.randn_like(cond) ) img, x0_partial = self.p_sample( img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised, return_x0=True, temperature=temperature[i], noise_dropout=noise_dropout, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, ) if mask is not None: assert x0 is not None img_orig = self.q_sample(x0, ts) img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) if callback: callback(i) if img_callback: img_callback(img, i) return img, intermediates @torch.no_grad() def p_sample_loop( self, cond, shape, return_intermediates=False, x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, start_T=None, log_every_t=None, ): if not log_every_t: log_every_t = self.log_every_t device = self.betas.device b = shape[0] if x_T is None: img = torch.randn(shape, device=device) else: img = x_T intermediates = [img] if timesteps is None: timesteps = self.num_timesteps if start_T is not None: timesteps = min(timesteps, start_T) iterator = ( tqdm( reversed(range(0, timesteps)), desc='Sampling t', total=timesteps, ) if verbose else reversed(range(0, timesteps)) ) if mask is not None: assert x0 is not None assert ( x0.shape[2:3] == mask.shape[2:3] ) # spatial size has to match for i in iterator: ts = torch.full((b,), i, device=device, dtype=torch.long) if 
self.shorten_cond_schedule: assert self.model.conditioning_key != 'hybrid' tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample( x_start=cond, t=tc, noise=torch.randn_like(cond) ) img = self.p_sample( img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised, ) if mask is not None: img_orig = self.q_sample(x0, ts) img = img_orig * mask + (1.0 - mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) if callback: callback(i) if img_callback: img_callback(img, i) if return_intermediates: return img, intermediates return img @torch.no_grad() def sample( self, cond, batch_size=16, return_intermediates=False, x_T=None, verbose=True, timesteps=None, quantize_denoised=False, mask=None, x0=None, shape=None, **kwargs, ): if shape is None: shape = ( batch_size, self.channels, self.image_size, self.image_size, ) if cond is not None: if isinstance(cond, dict): cond = { key: cond[key][:batch_size] if not isinstance(cond[key], list) else list(map(lambda x: x[:batch_size], cond[key])) for key in cond } else: cond = ( [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] ) return self.p_sample_loop( cond, shape, return_intermediates=return_intermediates, x_T=x_T, verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, mask=mask, x0=x0, ) @torch.no_grad() def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): if ddim: ddim_sampler = DDIMSampler(self) shape = (self.channels, self.image_size, self.image_size) samples, intermediates = ddim_sampler.sample( ddim_steps, batch_size, shape, cond, verbose=False, **kwargs ) else: samples, intermediates = self.sample( cond=cond, batch_size=batch_size, return_intermediates=True, **kwargs, ) return samples, intermediates @torch.no_grad() def log_images( self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1.0, return_keys=None, quantize_denoised=True, inpaint=False, plot_denoise_rows=False, 
plot_progressive_rows=False, plot_diffusion_rows=False, **kwargs, ): use_ddim = ddim_steps is not None log = dict() z, c, x, xrec, xc = self.get_input( batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, return_original_cond=True, bs=N, ) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) log['inputs'] = x log['reconstruction'] = xrec if self.model.conditioning_key is not None: if hasattr(self.cond_stage_model, 'decode'): xc = self.cond_stage_model.decode(c) log['conditioning'] = xc elif self.cond_stage_key in ['caption']: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch['caption']) log['conditioning'] = xc elif self.cond_stage_key == 'class_label': xc = log_txt_as_img( (x.shape[2], x.shape[3]), batch['human_label'] ) log['conditioning'] = xc elif isimage(xc): log['conditioning'] = xc if ismap(xc): log['original_conditioning'] = self.to_rgb(xc) if plot_diffusion_rows: # get diffusion row diffusion_row = list() z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: t = repeat(torch.tensor([t]), '1 -> b', b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) diffusion_row = torch.stack( diffusion_row ) # n_log_step, n_row, C, H, W diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') diffusion_grid = rearrange( diffusion_grid, 'b n c h w -> (b n) c h w' ) diffusion_grid = make_grid( diffusion_grid, nrow=diffusion_row.shape[0] ) log['diffusion_row'] = diffusion_grid if sample: # get denoise row with self.ema_scope('Plotting'): samples, z_denoise_row = self.sample_log( cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log['samples'] = x_samples if plot_denoise_rows: 
denoise_grid = self._get_denoise_row_from_list(z_denoise_row) log['denoise_row'] = denoise_grid uc = self.get_learned_conditioning(len(c) * ['']) sample_scaled, _ = self.sample_log( cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, unconditional_guidance_scale=5.0, unconditional_conditioning=uc, ) log['samples_scaled'] = self.decode_first_stage(sample_scaled) if ( quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(self.first_stage_model, IdentityFirstStage) ): # also display when quantizing x0 while sampling with self.ema_scope('Plotting Quantized Denoised'): samples, z_denoise_row = self.sample_log( cond=c, batch_size=N, ddim=use_ddim, ddim_steps=ddim_steps, eta=ddim_eta, quantize_denoised=True, ) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, # quantize_denoised=True) x_samples = self.decode_first_stage(samples.to(self.device)) log['samples_x0_quantized'] = x_samples if inpaint: # make a simple center square b, h, w = z.shape[0], z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in mask[:, h // 4 : 3 * h // 4, w // 4 : 3 * w // 4] = 0.0 mask = mask[:, None, ...] 
with self.ema_scope('Plotting Inpaint'): samples, _ = self.sample_log( cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask, ) x_samples = self.decode_first_stage(samples.to(self.device)) log['samples_inpainting'] = x_samples log['mask'] = mask # outpaint with self.ema_scope('Plotting Outpaint'): samples, _ = self.sample_log( cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask, ) x_samples = self.decode_first_stage(samples.to(self.device)) log['samples_outpainting'] = x_samples if plot_progressive_rows: with self.ema_scope('Plotting Progressives'): img, progressives = self.progressive_denoising( c, shape=(self.channels, self.image_size, self.image_size), batch_size=N, ) prog_row = self._get_denoise_row_from_list( progressives, desc='Progressive Generation' ) log['progressive_row'] = prog_row if return_keys: if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: return log else: return {key: log[key] for key in return_keys} return log def configure_optimizers(self): lr = self.learning_rate if self.embedding_manager is not None: params = list(self.embedding_manager.embedding_parameters()) # params = list(self.cond_stage_model.transformer.text_model.embeddings.embedding_manager.embedding_parameters()) else: params = list(self.model.parameters()) if self.cond_stage_trainable: print( f'{self.__class__.__name__}: Also optimizing conditioner params!' 
) params = params + list(self.cond_stage_model.parameters()) if self.learn_logvar: print('Diffusion model optimizing logvar') params.append(self.logvar) opt = torch.optim.AdamW(params, lr=lr) if self.use_scheduler: assert 'target' in self.scheduler_config scheduler = instantiate_from_config(self.scheduler_config) print('Setting up LambdaLR scheduler...') scheduler = [ { 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1, } ] return [opt], scheduler return opt @torch.no_grad() def to_rgb(self, x): x = x.float() if not hasattr(self, 'colorize'): self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) x = nn.functional.conv2d(x, weight=self.colorize) x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 return x @rank_zero_only def on_save_checkpoint(self, checkpoint): checkpoint.clear() if os.path.isdir(self.trainer.checkpoint_callback.dirpath): self.embedding_manager.save( os.path.join( self.trainer.checkpoint_callback.dirpath, 'embeddings.pt' ) ) if (self.global_step - self.emb_ckpt_counter) > 500: self.embedding_manager.save( os.path.join( self.trainer.checkpoint_callback.dirpath, f'embeddings_gs-{self.global_step}.pt', ) ) self.emb_ckpt_counter += 500 class DiffusionWrapper(pl.LightningModule): def __init__(self, diff_model_config, conditioning_key): super().__init__() self.diffusion_model = instantiate_from_config(diff_model_config) self.conditioning_key = conditioning_key assert self.conditioning_key in [ None, 'concat', 'crossattn', 'hybrid', 'adm', ] def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): if self.conditioning_key is None: out = self.diffusion_model(x, t) elif self.conditioning_key == 'concat': xc = torch.cat([x] + c_concat, dim=1) out = self.diffusion_model(xc, t) elif self.conditioning_key == 'crossattn': cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(x, t, context=cc) elif self.conditioning_key == 'hybrid': xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 
1) out = self.diffusion_model(xc, t, context=cc) elif self.conditioning_key == 'adm': cc = c_crossattn[0] out = self.diffusion_model(x, t, y=cc) else: raise NotImplementedError() return out class Layout2ImgDiffusion(LatentDiffusion): # TODO: move all layout-specific hacks to this class def __init__(self, cond_stage_key, *args, **kwargs): assert ( cond_stage_key == 'coordinates_bbox' ), 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) def log_images(self, batch, N=8, *args, **kwargs): logs = super().log_images(batch=batch, N=N, *args, **kwargs) key = 'train' if self.training else 'validation' dset = self.trainer.datamodule.datasets[key] mapper = dset.conditional_builders[self.cond_stage_key] bbox_imgs = [] map_fn = lambda catno: dset.get_textual_label( dset.get_category_id(catno) ) for tknzd_bbox in batch[self.cond_stage_key][:N]: bboximg = mapper.plot( tknzd_bbox.detach().cpu(), map_fn, (256, 256) ) bbox_imgs.append(bboximg) cond_img = torch.stack(bbox_imgs, dim=0) logs['bbox_image'] = cond_img return logs ================================================ FILE: src/stablediffusion/ldm/models/diffusion/ksampler.py ================================================ """wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers""" import k_diffusion as K import torch import torch.nn as nn from src.stablediffusion.ldm.dream.devices import choose_torch_device class CFGDenoiser(nn.Module): def __init__(self, model): super().__init__() self.inner_model = model def forward(self, x, sigma, uncond, cond, cond_scale): x_in = torch.cat([x] * 2) sigma_in = torch.cat([sigma] * 2) cond_in = torch.cat([uncond, cond]) uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) return uncond + (cond - uncond) * cond_scale class KSampler(object): def __init__(self, model, schedule='lms', device=None, **kwargs): super().__init__() self.model = 
K.external.CompVisDenoiser(model) self.schedule = schedule self.device = device or choose_torch_device() def forward(self, x, sigma, uncond, cond, cond_scale): x_in = torch.cat([x] * 2) sigma_in = torch.cat([sigma] * 2) cond_in = torch.cat([uncond, cond]) uncond, cond = self.inner_model( x_in, sigma_in, cond=cond_in ).chunk(2) return uncond + (cond - uncond) * cond_scale # most of these arguments are ignored and are only present for compatibility with # other samples @torch.no_grad() def sample( self, S, batch_size, shape, conditioning=None, callback=None, normals_sequence=None, img_callback=None, quantize_x0=False, eta=0.0, mask=None, x0=None, temperature=1.0, noise_dropout=0.0, score_corrector=None, corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1.0, unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... **kwargs, ): def route_callback(k_callback_values): if img_callback is not None: img_callback(k_callback_values['x'], k_callback_values['i']) sigmas = self.model.get_sigmas(S) if x_T is not None: x = x_T * sigmas[0] else: x = ( torch.randn([batch_size, *shape], device=self.device) * sigmas[0] ) # for GPU draw model_wrap_cfg = CFGDenoiser(self.model) extra_args = { 'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale, } return ( K.sampling.__dict__[f'sample_{self.schedule}']( model_wrap_cfg, x, sigmas, extra_args=extra_args, callback=route_callback ), None, ) ================================================ FILE: src/stablediffusion/ldm/models/diffusion/plms.py ================================================ """SAMPLING ONLY.""" import torch import numpy as np from tqdm import tqdm from functools import partial from src.stablediffusion.ldm.dream.devices import choose_torch_device from src.stablediffusion.ldm.modules.diffusionmodules.util import ( make_ddim_sampling_parameters, make_ddim_timesteps, 
noise_like, ) class PLMSSampler(object): def __init__(self, model, schedule='linear', device=None, **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule self.device = device if device else choose_torch_device() def register_buffer(self, name, attr): if type(attr) == torch.Tensor: if attr.device != torch.device(self.device): attr = attr.to(torch.float32).to(torch.device(self.device)) setattr(self, name, attr) def make_schedule( self, ddim_num_steps, ddim_discretize='uniform', ddim_eta=0.0, verbose=True, ): if ddim_eta != 0: raise ValueError('ddim_eta must be 0 for PLMS') self.ddim_timesteps = make_ddim_timesteps( ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose, ) alphas_cumprod = self.model.alphas_cumprod assert ( alphas_cumprod.shape[0] == self.ddpm_num_timesteps ), 'alphas have to be defined for each timestep' to_torch = ( lambda x: x.clone() .detach() .to(torch.float32) .to(self.model.device) ) self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer( 'alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev) ) # calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer( 'sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())) ) self.register_buffer( 'sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), ) self.register_buffer( 'log_one_minus_alphas_cumprod', to_torch(np.log(1.0 - alphas_cumprod.cpu())), ) self.register_buffer( 'sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())), ) self.register_buffer( 'sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), ) # ddim sampling parameters ( ddim_sigmas, ddim_alphas, ddim_alphas_prev, ) = make_ddim_sampling_parameters( alphacums=alphas_cumprod.cpu(), 
ddim_timesteps=self.ddim_timesteps, eta=ddim_eta, verbose=verbose, ) self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer( 'ddim_sqrt_one_minus_alphas', np.sqrt(1.0 - ddim_alphas) ) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) ) self.register_buffer( 'ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps, ) @torch.no_grad() def sample( self, S, batch_size, shape, conditioning=None, callback=None, normals_sequence=None, img_callback=None, quantize_x0=False, eta=0.0, mask=None, x0=None, temperature=1.0, noise_dropout=0.0, score_corrector=None, corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1.0, unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
**kwargs, ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] if cbs != batch_size: print( f'Warning: Got {cbs} conditionings but batch-size is {batch_size}' ) else: if conditioning.shape[0] != batch_size: print( f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}' ) self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) # sampling C, H, W = shape size = (batch_size, C, H, W) # print(f'Data shape for PLMS sampling is {size}') samples, intermediates = self.plms_sampling( conditioning, size, callback=callback, img_callback=img_callback, quantize_denoised=quantize_x0, mask=mask, x0=x0, ddim_use_original_steps=False, noise_dropout=noise_dropout, temperature=temperature, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, x_T=x_T, log_every_t=log_every_t, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, ) return samples, intermediates @torch.no_grad() def plms_sampling( self, cond, shape, x_T=None, ddim_use_original_steps=False, callback=None, timesteps=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1.0, noise_dropout=0.0, score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, ): device = self.model.betas.device b = shape[0] if x_T is None: img = torch.randn(shape, device=device) else: img = x_T if timesteps is None: timesteps = ( self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps ) elif timesteps is not None and not ddim_use_original_steps: subset_end = ( int( min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0] ) - 1 ) timesteps = self.ddim_timesteps[:subset_end] intermediates = {'x_inter': [img], 'pred_x0': [img]} time_range = ( list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps) 
) total_steps = ( timesteps if ddim_use_original_steps else timesteps.shape[0] ) # print(f"Running PLMS Sampling with {total_steps} timesteps") iterator = tqdm( time_range, desc='PLMS Sampler', total=total_steps, dynamic_ncols=True, ) old_eps = [] for i, step in enumerate(iterator): index = total_steps - i - 1 ts = torch.full((b,), step, device=device, dtype=torch.long) ts_next = torch.full( (b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long, ) if mask is not None: assert x0 is not None img_orig = self.model.q_sample( x0, ts ) # TODO: deterministic forward pass? img = img_orig * mask + (1.0 - mask) * img outs = self.p_sample_plms( img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, quantize_denoised=quantize_denoised, temperature=temperature, noise_dropout=noise_dropout, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, old_eps=old_eps, t_next=ts_next, ) img, pred_x0, e_t = outs old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) if callback: callback(i) if img_callback: img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(img) intermediates['pred_x0'].append(pred_x0) return img, intermediates @torch.no_grad() def p_sample_plms( self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1.0, noise_dropout=0.0, score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, old_eps=None, t_next=None, ): b, *_, device = *x.shape, x.device def get_model_output(x, t): if ( unconditional_conditioning is None or unconditional_guidance_scale == 1.0 ): e_t = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) t_in = torch.cat([t] * 2) c_in = torch.cat([unconditional_conditioning, c]) e_t_uncond, e_t = 
self.model.apply_model( x_in, t_in, c_in ).chunk(2) e_t = e_t_uncond + unconditional_guidance_scale * ( e_t - e_t_uncond ) if score_corrector is not None: assert self.model.parameterization == 'eps' e_t = score_corrector.modify_score( self.model, e_t, x, t, c, **corrector_kwargs ) return e_t alphas = ( self.model.alphas_cumprod if use_original_steps else self.ddim_alphas ) alphas_prev = ( self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev ) sqrt_one_minus_alphas = ( self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas ) sigmas = ( self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas ) def get_x_prev_and_pred_x0(e_t, index): # select parameters corresponding to the currently considered timestep a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full( (b, 1, 1, 1), alphas_prev[index], device=device ) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) sqrt_one_minus_at = torch.full( (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device ) # current prediction for x_0 pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) # direction pointing to x_t dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t noise = ( sigma_t * noise_like(x.shape, device, repeat_noise) * temperature ) if noise_dropout > 0.0: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 e_t = get_model_output(x, t) if len(old_eps) == 0: # Pseudo Improved Euler (2nd order) x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) e_t_next = get_model_output(x_prev, t_next) e_t_prime = (e_t + e_t_next) / 2 elif len(old_eps) == 1: # 2nd order Pseudo Linear Multistep (Adams-Bashforth) e_t_prime = (3 * e_t - old_eps[-1]) / 2 elif len(old_eps) == 2: # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 elif len(old_eps) >= 3: # 4nd order Pseudo Linear Multistep (Adams-Bashforth) e_t_prime = ( 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3] ) / 24 x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) return x_prev, pred_x0, e_t ================================================ FILE: src/stablediffusion/ldm/modules/attention.py ================================================ from inspect import isfunction import math import torch import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat from src.stablediffusion.ldm.modules.diffusionmodules.util import checkpoint import psutil def exists(val): return val is not None def uniq(arr): return{el: True for el in arr}.keys() def default(val, d): if exists(val): return val return d() if isfunction(d) else d def max_neg_value(t): return -torch.finfo(t.dtype).max def init_(tensor): dim = tensor.shape[-1] std = 1 / math.sqrt(dim) tensor.uniform_(-std, std) return tensor # feedforward class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) return x * F.gelu(gate) class FeedForward(nn.Module): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = nn.Sequential( nn.Linear(dim, inner_dim), nn.GELU() ) if not glu else GEGLU(dim, inner_dim) self.net = nn.Sequential( project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) ) def forward(self, x): return self.net(x) def zero_module(module): """ Zero out the parameters of a module and return it. 
""" for p in module.parameters(): p.detach().zero_() return module def Normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) class LinearAttention(nn.Module): def __init__(self, dim, heads=4, dim_head=32): super().__init__() self.heads = heads hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) k = k.softmax(dim=-1) context = torch.einsum('bhdn,bhen->bhde', k, v) out = torch.einsum('bhde,bhdn->bhen', context, q) out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) return self.to_out(out) class SpatialSelfAttention(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) # compute attention b,c,h,w = q.shape q = rearrange(q, 'b c h w -> b (h w) c') k = rearrange(k, 'b c h w -> b c (h w)') w_ = torch.einsum('bij,bjk->bik', q, k) w_ = w_ * (int(c)**(-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values v = rearrange(v, 'b c h w -> b c (h w)') w_ = rearrange(w_, 'b i j -> b j i') h_ = torch.einsum('bij,bjk->bik', v, w_) h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) h_ = self.proj_out(h_) return x+h_ class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, 
dropout=0.): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) self.scale = dim_head ** -0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) if not torch.cuda.is_available(): mem_av = psutil.virtual_memory().available / (1024**3) if mem_av > 32: self.einsum_op = self.einsum_op_v1 elif mem_av > 12: self.einsum_op = self.einsum_op_v2 else: self.einsum_op = self.einsum_op_v3 del mem_av else: self.einsum_op = self.einsum_op_v4 # mps 64-128 GB def einsum_op_v1(self, q, k, v, r1): if q.shape[1] <= 4096: # for 512x512: the max q.shape[1] is 4096 s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # aggressive/faster: operation in one go s2 = s1.softmax(dim=-1, dtype=q.dtype) del s1 r1 = einsum('b i j, b j d -> b i d', s2, v) del s2 else: # q.shape[0] * q.shape[1] * slice_size >= 2**31 throws err # needs around half of that slice_size to not generate noise slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) for i in range(0, q.shape[1], slice_size): end = i + slice_size s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale s2 = s1.softmax(dim=-1, dtype=r1.dtype) del s1 r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 return r1 # mps 16-32 GB (can be optimized) def einsum_op_v2(self, q, k, v, r1): slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) for i in range(0, q.shape[1], slice_size): # conservative/less mem: operation in steps end = i + slice_size s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale s2 = s1.softmax(dim=-1, dtype=r1.dtype) del s1 r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 return r1 # mps 8 GB def einsum_op_v3(self, q, k, v, r1): slice_size = 1 for i in range(0, q.shape[0], slice_size): # iterate over q.shape[0] 
end = min(q.shape[0], i + slice_size) s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) # adapted einsum for mem s1 *= self.scale s2 = s1.softmax(dim=-1, dtype=r1.dtype) del s1 r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) # adapted einsum for mem del s2 return r1 # cuda def einsum_op_v4(self, q, k, v, r1): stats = torch.cuda.memory_stats(q.device) mem_active = stats['active_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current'] mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 mem_required = tensor_size * 2.5 steps = 1 if mem_required > mem_free_total: steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) if steps > 64: max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = min(q.shape[1], i + slice_size) s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale s2 = s1.softmax(dim=-1, dtype=r1.dtype) del s1 r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 return r1 def forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) k_in = self.to_k(context) v_in = self.to_v(context) device_type = 'mps' if x.device.type == 'mps' else 'cuda' del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) r1 = self.einsum_op(q, k, v, r1) del q, k, v r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 return self.to_out(r2) class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) self.checkpoint = checkpoint def forward(self, x, context=None): return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) def _forward(self, x, context=None): x = x.contiguous() if x.device.type == 'mps' else x x = self.attn1(self.norm1(x)) + x x = self.attn2(self.norm2(x), context=context) + x x = self.ff(self.norm3(x)) + x return x class SpatialTransformer(nn.Module): """ Transformer block for image-like 
data. First, project the input (aka embedding) and reshape to b, t, d. Then apply standard transformer action. Finally, reshape to image """ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None): super().__init__() self.in_channels = in_channels inner_dim = n_heads * d_head self.norm = Normalize(in_channels) self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) self.transformer_blocks = nn.ModuleList( [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) for d in range(depth)] ) self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) def forward(self, x, context=None): # note: if no context is given, cross-attention defaults to self-attention b, c, h, w = x.shape x_in = x x = self.norm(x) x = self.proj_in(x) x = rearrange(x, 'b c h w -> b (h w) c') for block in self.transformer_blocks: x = block(x, context=context) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = self.proj_out(x) return x + x_in ================================================ FILE: src/stablediffusion/ldm/modules/diffusionmodules/__init__.py ================================================ ================================================ FILE: src/stablediffusion/ldm/modules/diffusionmodules/model.py ================================================ # pytorch_diffusion + derived encoder decoder import gc import math import torch import torch.nn as nn import numpy as np from einops import rearrange from src.stablediffusion.ldm.util import instantiate_from_config from src.stablediffusion.ldm.modules.attention import LinearAttention import psutil def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal embeddings. 
This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You Need". """ assert len(timesteps.shape) == 1 half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0,1,0,0)) return emb def nonlinearity(x): # swish return x*torch.sigmoid(x) def Normalize(in_channels, num_groups=32): return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) class Upsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") if self.with_conv: x = self.conv(x) return x class Downsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) def forward(self, x): if self.with_conv: pad = (0,1,0,1) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) return x class ResnetBlock(nn.Module): def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels) self.conv1 = 
torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x, temb): h1 = x h2 = self.norm1(h1) del h1 h3 = nonlinearity(h2) del h2 h4 = self.conv1(h3) del h3 if temb is not None: h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None] h5 = self.norm2(h4) del h4 h6 = nonlinearity(h5) del h5 h7 = self.dropout(h6) del h6 h8 = self.conv2(h7) del h7 if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x + h8 class LinAttnBlock(LinearAttention): """to match AttnBlock usage""" def __init__(self, in_channels): super().__init__(dim=in_channels, heads=1, dim_head=in_channels) class AttnBlock(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x h_ = self.norm(h_) q1 = self.q(h_) k1 = self.k(h_) v = self.v(h_) # compute attention b, c, h, w = q1.shape q2 = q1.reshape(b, c, h*w) del q1 q = q2.permute(0, 2, 1) # b,hw,c del q2 k = k1.reshape(b, c, h*w) # b,c,hw del k1 h_ = 
torch.zeros_like(k, device=q.device) device_type = 'mps' if q.device.type == 'mps' else 'cuda' if device_type == 'cuda': stats = torch.cuda.memory_stats(q.device) mem_active = stats['active_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current'] mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 mem_required = tensor_size * 2.5 steps = 1 if mem_required > mem_free_total: steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] else: if psutil.virtual_memory().available / (1024**3) < 12: slice_size = 1 else: slice_size = min(q.shape[1], math.floor(2**30 / (q.shape[0] * q.shape[1]))) for i in range(0, q.shape[1], slice_size): end = i + slice_size w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] w2 = w1 * (int(c)**(-0.5)) del w1 w3 = torch.nn.functional.softmax(w2, dim=2) del w2 # attend to values v1 = v.reshape(b, c, h*w) w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) del w3 h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] del v1, w4 h2 = h_.reshape(b, c, h, w) del h_ h3 = self.proj_out(h2) del h2 h3 += x return h3 def make_attn(in_channels, attn_type="vanilla"): assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' print(f"making attention of type '{attn_type}' with {in_channels} in_channels") if attn_type == "vanilla": return AttnBlock(in_channels) elif attn_type == "none": return nn.Identity(in_channels) else: return LinAttnBlock(in_channels) class Model(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): 
super().__init__() if use_linear_attn: attn_type = "linear" self.ch = ch self.temb_ch = self.ch*4 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels self.use_timestep = use_timestep if self.use_timestep: # timestep embedding self.temb = nn.Module() self.temb.dense = nn.ModuleList([ torch.nn.Linear(self.ch, self.temb_ch), torch.nn.Linear(self.temb_ch, self.temb_ch), ]) # downsampling self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = resolution in_ch_mult = (1,)+tuple(ch_mult) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = ch*in_ch_mult[i_level] block_out = ch*ch_mult[i_level] for i_block in range(self.num_res_blocks): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions-1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = ch*ch_mult[i_level] skip_in = ch*ch_mult[i_level] for i_block in range(self.num_res_blocks+1): if i_block == self.num_res_blocks: skip_in = ch*in_ch_mult[i_level] block.append(ResnetBlock(in_channels=block_in+skip_in, out_channels=block_out, 
temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x, t=None, context=None):
        """Run the U-Net on ``x``.

        Args:
            x: input feature map.
            t: 1-D tensor of timesteps; required when ``use_timestep`` is set.
            context: optional tensor concatenated to ``x`` along dim 1 before
                ``conv_in``.

        Every down-path output is pushed onto ``hs``. Each up-path ResnetBlock
        pops one and concatenates it with ``h`` along dim 1 (skip connection).
        """
        #assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks+1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        """Return the weight tensor of the final ``conv_out`` layer."""
        return self.conv_out.weight


class Encoder(nn.Module):
    """Convolutional encoder: ResnetBlock/attention down path, middle block, conv_out.

    No timestep embedding is used (``temb_ch = 0``). ``conv_out`` emits
    ``2*z_channels`` channels when ``double_z`` is set, else ``z_channels``.
    Extra keyword arguments are accepted and ignored.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        in_ch_mult = (1,)+tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        """Encode ``x``; the result comes straight from ``conv_out`` (see class docstring)."""
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions-1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    """Convolutional decoder: conv_in, middle block, ResnetBlock/attention up path, conv_out.

    No timestep embedding is used (``temb_ch = 0``). Extra keyword arguments
    are accepted and ignored.
    """
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
                 attn_type="vanilla", **ignorekwargs):
        super().__init__()
        if use_linear_attn: attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in,
attn_type=attn_type)) up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, z): #assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding temb = None # z to block_in h1 = self.conv_in(z) # middle h2 = self.mid.block_1(h1, temb) del h1 h3 = self.mid.attn_1(h2) del h2 h = self.mid.block_2(h3, temb) del h3 # prepare for up sampling device_type = 'mps' if h.device.type == 'mps' else 'cuda' gc.collect() if device_type == 'cuda': torch.cuda.empty_cache() # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: t = h h = self.up[i_level].attn[i_block](t) del t if i_level != 0: t = h h = self.up[i_level].upsample(t) del t # end if self.give_pre_end: return h h1 = self.norm_out(h) del h h2 = nonlinearity(h1) del h1 h = self.conv_out(h2) del h2 if self.tanh_out: t = h h = torch.tanh(t) del t return h class SimpleDecoder(nn.Module): def __init__(self, in_channels, out_channels, *args, **kwargs): super().__init__() self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), ResnetBlock(in_channels=in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), ResnetBlock(in_channels=2 * in_channels, out_channels=4 * in_channels, temb_channels=0, dropout=0.0), ResnetBlock(in_channels=4 * in_channels, out_channels=2 * in_channels, temb_channels=0, dropout=0.0), nn.Conv2d(2*in_channels, in_channels, 1), Upsample(in_channels, with_conv=True)]) # end self.norm_out = Normalize(in_channels) self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): for i, layer 
in enumerate(self.model): if i in [1,2,3]: x = layer(x, None) else: x = layer(x) h = self.norm_out(x) h = nonlinearity(h) x = self.conv_out(h) return x class UpsampleDecoder(nn.Module): def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, ch_mult=(2,2), dropout=0.0): super().__init__() # upsampling self.temb_ch = 0 self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks block_in = in_channels curr_res = resolution // 2 ** (self.num_resolutions - 1) self.res_blocks = nn.ModuleList() self.upsample_blocks = nn.ModuleList() for i_level in range(self.num_resolutions): res_block = [] block_out = ch * ch_mult[i_level] for i_block in range(self.num_res_blocks + 1): res_block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) block_in = block_out self.res_blocks.append(nn.ModuleList(res_block)) if i_level != self.num_resolutions - 1: self.upsample_blocks.append(Upsample(block_in, True)) curr_res = curr_res * 2 # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): # upsampling h = x for k, i_level in enumerate(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.res_blocks[i_level][i_block](h, None) if i_level != self.num_resolutions - 1: h = self.upsample_blocks[k](h) h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h) return h class LatentRescaler(nn.Module): def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): super().__init__() # residual block, interpolate, residual block self.factor = factor self.conv_in = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1) self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) for _ in range(depth)]) self.attn = AttnBlock(mid_channels) self.res_block2 = 
nn.ModuleList([ResnetBlock(in_channels=mid_channels, out_channels=mid_channels, temb_channels=0, dropout=0.0) for _ in range(depth)]) self.conv_out = nn.Conv2d(mid_channels, out_channels, kernel_size=1, ) def forward(self, x): x = self.conv_in(x) for block in self.res_block1: x = block(x, None) x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) x = self.attn(x) for block in self.res_block2: x = block(x, None) x = self.conv_out(x) return x class MergedRescaleEncoder(nn.Module): def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): super().__init__() intermediate_chn = ch * ch_mult[-1] self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, z_channels=intermediate_chn, double_z=False, resolution=resolution, attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, out_ch=None) self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) def forward(self, x): x = self.encoder(x) x = self.rescaler(x) return x class MergedRescaleDecoder(nn.Module): def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): super().__init__() tmp_chn = z_channels*ch_mult[-1] self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, ch_mult=ch_mult, resolution=resolution, ch=ch) self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, out_channels=tmp_chn, depth=rescale_module_depth) def forward(self, x): x = 
self.rescaler(x) x = self.decoder(x) return x class Upsampler(nn.Module): def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): super().__init__() assert out_size >= in_size num_blocks = int(np.log2(out_size//in_size))+1 factor_up = 1.+ (out_size % in_size) print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, out_channels=in_channels) self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, attn_resolutions=[], in_channels=None, ch=in_channels, ch_mult=[ch_mult for _ in range(num_blocks)]) def forward(self, x): x = self.rescaler(x) x = self.decoder(x) return x class Resize(nn.Module): def __init__(self, in_channels=None, learned=False, mode="bilinear"): super().__init__() self.with_conv = learned self.mode = mode if self.with_conv: print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") raise NotImplementedError() assert in_channels is not None # no asymmetric padding in torch conv, must do it ourselves self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=4, stride=2, padding=1) def forward(self, x, scale_factor=1.0): if scale_factor==1.0: return x else: x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) return x class FirstStagePostProcessor(nn.Module): def __init__(self, ch_mult:list, in_channels, pretrained_model:nn.Module=None, reshape=False, n_channels=None, dropout=0., pretrained_config=None): super().__init__() if pretrained_config is None: assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' self.pretrained_model = pretrained_model else: assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' 
self.instantiate_pretrained(pretrained_config) self.do_reshape = reshape if n_channels is None: n_channels = self.pretrained_model.encoder.ch self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, stride=1,padding=1) blocks = [] downs = [] ch_in = n_channels for m in ch_mult: blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) ch_in = m * n_channels downs.append(Downsample(ch_in, with_conv=False)) self.model = nn.ModuleList(blocks) self.downsampler = nn.ModuleList(downs) def instantiate_pretrained(self, config): model = instantiate_from_config(config) self.pretrained_model = model.eval() # self.pretrained_model.train = False for param in self.pretrained_model.parameters(): param.requires_grad = False @torch.no_grad() def encode_with_pretrained(self,x): c = self.pretrained_model.encode(x) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() return c def forward(self,x): z_fs = self.encode_with_pretrained(x) z = self.proj_norm(z_fs) z = self.proj(z) z = nonlinearity(z) for submodel, downmodel in zip(self.model,self.downsampler): z = submodel(z,temb=None) z = downmodel(z) if self.do_reshape: z = rearrange(z,'b c h w -> b (h w) c') return z ================================================ FILE: src/stablediffusion/ldm/modules/diffusionmodules/openaimodel.py ================================================ from abc import abstractmethod from functools import partial import math from typing import Iterable import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F from src.stablediffusion.ldm.modules.diffusionmodules.util import ( checkpoint, conv_nd, linear, avg_pool_nd, zero_module, normalization, timestep_embedding, ) from src.stablediffusion.ldm.modules.attention import SpatialTransformer # dummy replace def convert_module_to_f16(x): pass def convert_module_to_f32(x): pass ## go class AttentionPool2d(nn.Module): """ 
    Attention pooling over the flattened spatial positions of a feature map.

    The spatial mean is prepended as an extra token; after attention only
    that token's output is returned.  The positional embedding has
    ``spacial_dim**2 + 1`` entries, so the input is expected to have
    ``spacial_dim**2`` spatial positions (e.g. a square map) — confirm at call sites.

    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
        )
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        # output width defaults to embed_dim when output_dim is None
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        """Pool ``x`` (B, C, *spatial) to a (B, output_dim) tensor."""
        b, c, *_spatial = x.shape
        x = x.reshape(b, c, -1)  # NC(HW)
        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
        x = self.qkv_proj(x)
        x = self.attention(x)
        x = self.c_proj(x)
        # keep only the prepended mean token
        return x[:, :, 0]


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.

    Dispatch per child: TimestepBlock -> ``layer(x, emb)``;
    SpatialTransformer -> ``layer(x, context)``; anything else -> ``layer(x)``.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the optional conv (defaults to
                 ``channels``).
    :param padding: padding of the optional 3x3 conv.
    """

    def __init__(
        self, channels, use_conv, dims=2, out_channels=None, padding=1
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(
                dims, self.channels, self.out_channels, 3, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # keep dim 2 as-is, double only the last two dimensions
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest'
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.use_conv:
            x = self.conv(x)
        return x


class TransposedUpsample(nn.Module):
    """Learned 2x upsampling without padding"""

    # NOTE(review): with no padding, PyTorch's ConvTranspose2d size formula
    # gives 2*H + ks - 2 per spatial dim (2*H + 3 for the default ks=5), i.e.
    # not an exact 2x.

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.up = nn.ConvTranspose2d(
            self.channels, self.out_channels, kernel_size=ks, stride=2
        )

    def forward(self, x):
        return self.up(x)


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels; must equal ``channels`` when
                 ``use_conv`` is False (average pooling cannot change them).
    :param padding: padding of the strided 3x3 conv.
    """

    def __init__(
        self, channels, use_conv, dims=2, out_channels=None, padding=1
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # for 3D signals, stride 1 on the first spatial dim leaves it untouched
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: if True, the embedding projection produces
        2 * out_channels values that are split into (scale, shift) and applied
        after the first output normalization, instead of being added to h.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # the same resampling is applied to the residual path (h) and the
        # skip path (x) so their spatial sizes stay equal
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels
                if use_scale_shift_norm
                else self.out_channels,
            ),
        )
        # NOTE(review): zero_module (util.py) presumably zero-initialises the
        # final conv so the block starts out as just its skip path — confirm.
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(
                    dims, self.out_channels, self.out_channels, 3, padding=1
                )
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 1
            )

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            # resample between in_layers' norm/activation and its final conv
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # append trailing singleton dims so emb_out broadcasts over h's
        # spatial dimensions
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        # num_head_channels (when != -1) overrides num_heads
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        # NOTE(review): the flag passed to checkpoint is hard-coded to True;
        # self.use_checkpoint is stored but not consulted here.
        return checkpoint(
            self._forward, (x,), self.parameters(), True
        )  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
        # return pt_checkpoint(self._forward, x)  # pytorch

    def _forward(self, x):
        # flatten all spatial dims into one sequence axis, attend, restore
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)


def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    """
    b, c, *spatial = y[0].shape
    num_spatial = int(np.prod(spatial))
    # We perform two matmuls with the same number of ops.
    # The first computes the weight matrix, the second computes
    # the combination of the value vectors.
    matmul_ops = 2 * b * (num_spatial**2) * c
    model.total_ops += th.DoubleTensor([matmul_ops])


class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # heads are split off first, then each head's rows split into q, k, v
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
            ch, dim=1
        )
        # 1/ch**0.25 on both q and k gives the usual 1/sqrt(ch) overall
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            'bct,bcs->bts', q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        # softmax is computed in float32, then cast back
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum('bts,bcs->bct', weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # q, k, v are split off first, then each is viewed per head
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            'bct,bcs->bts',
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum(
            'bts,bcs->bct', weight, v.reshape(bs * self.n_heads, ch, length)
        )
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample rates at which attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x downsampling, attention will be used. :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param conv_resample: if True, use learned convolutions for upsampling and downsampling. :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this model will be class-conditional with `num_classes` classes. :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width per attention head. :param num_heads_upsample: works with num_heads to set a different number of heads for upsampling. Deprecated. :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially increased efficiency. 
""" def __init__( self, image_size, in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_classes=None, use_checkpoint=False, use_fp16=False, num_heads=-1, num_head_channels=-1, num_heads_upsample=-1, use_scale_shift_norm=False, resblock_updown=False, use_new_attention_order=False, use_spatial_transformer=False, # custom transformer support transformer_depth=1, # custom transformer support context_dim=None, # custom transformer support n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model legacy=True, ): super().__init__() if use_spatial_transformer: assert ( context_dim is not None ), 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' if context_dim is not None: assert ( use_spatial_transformer ), 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' from omegaconf.listconfig import ListConfig if type(context_dim) == ListConfig: context_dim = list(context_dim) if num_heads_upsample == -1: num_heads_upsample = num_heads if num_heads == -1: assert ( num_head_channels != -1 ), 'Either num_heads or num_head_channels has to be set' if num_head_channels == -1: assert ( num_heads != -1 ), 'Either num_heads or num_head_channels has to be set' self.image_size = image_size self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels self.num_res_blocks = num_res_blocks self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample self.num_classes = num_classes self.use_checkpoint = use_checkpoint self.dtype = th.float16 if use_fp16 else th.float32 self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample self.predict_codebook_ids = n_embed is not None time_embed_dim = model_channels * 4 
self.time_embed = nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) if self.num_classes is not None: self.label_emb = nn.Embedding(num_classes, time_embed_dim) self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( conv_nd(dims, in_channels, model_channels, 3, padding=1) ) ] ) self._feature_size = model_channels input_block_chans = [model_channels] ch = model_channels ds = 1 for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): layers = [ ResBlock( ch, time_embed_dim, dropout, out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = mult * model_channels if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: # num_heads = 1 dim_head = ( ch // num_heads if use_spatial_transformer else num_head_channels ) layers.append( AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, ) if not use_spatial_transformer else SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: out_ch = ch self.input_blocks.append( TimestepEmbedSequential( ResBlock( ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, ) if resblock_updown else Downsample( ch, conv_resample, dims=dims, out_channels=out_ch ) ) ) ch = out_ch input_block_chans.append(ch) ds *= 2 self._feature_size += ch if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: # num_heads = 1 dim_head = ( ch // 
num_heads if use_spatial_transformer else num_head_channels ) self.middle_block = TimestepEmbedSequential( ResBlock( ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, ) if not use_spatial_transformer else SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, ), ResBlock( ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), ) self._feature_size += ch self.output_blocks = nn.ModuleList([]) for level, mult in list(enumerate(channel_mult))[::-1]: for i in range(num_res_blocks + 1): ich = input_block_chans.pop() layers = [ ResBlock( ch + ich, time_embed_dim, dropout, out_channels=model_channels * mult, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = model_channels * mult if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: # num_heads = 1 dim_head = ( ch // num_heads if use_spatial_transformer else num_head_channels ) layers.append( AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads_upsample, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, ) if not use_spatial_transformer else SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, ) ) if level and i == num_res_blocks: out_ch = ch layers.append( ResBlock( ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, up=True, ) if resblock_updown else Upsample( ch, conv_resample, dims=dims, out_channels=out_ch ) ) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) 
self._feature_size += ch self.out = nn.Sequential( normalization(ch), nn.SiLU(), zero_module( conv_nd(dims, model_channels, out_channels, 3, padding=1) ), ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( normalization(ch), conv_nd(dims, model_channels, n_embed, 1), # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits ) def convert_to_fp16(self): """ Convert the torso of the model to float16. """ self.input_blocks.apply(convert_module_to_f16) self.middle_block.apply(convert_module_to_f16) self.output_blocks.apply(convert_module_to_f16) def convert_to_fp32(self): """ Convert the torso of the model to float32. """ self.input_blocks.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) def forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. :param context: conditioning plugged in via crossattn :param y: an [N] Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs. """ assert (y is not None) == ( self.num_classes is not None ), 'must specify y if and only if the model is class-conditional' hs = [] t_emb = timestep_embedding( timesteps, self.model_channels, repeat_only=False ) emb = self.time_embed(t_emb) if self.num_classes is not None: assert y.shape == (x.shape[0],) emb = emb + self.label_emb(y) h = x.type(self.dtype) for module in self.input_blocks: h = module(h, emb, context) hs.append(h) h = self.middle_block(h, emb, context) for module in self.output_blocks: h = th.cat([h, hs.pop()], dim=1) h = module(h, emb, context) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) else: return self.out(h) class EncoderUNetModel(nn.Module): """ The half UNet model with attention and timestep embedding. For usage, see UNet. 
""" def __init__( self, image_size, in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, use_checkpoint=False, use_fp16=False, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, use_scale_shift_norm=False, resblock_updown=False, use_new_attention_order=False, pool='adaptive', *args, **kwargs, ): super().__init__() if num_heads_upsample == -1: num_heads_upsample = num_heads self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels self.num_res_blocks = num_res_blocks self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample self.use_checkpoint = use_checkpoint self.dtype = th.float16 if use_fp16 else th.float32 self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( conv_nd(dims, in_channels, model_channels, 3, padding=1) ) ] ) self._feature_size = model_channels input_block_chans = [model_channels] ch = model_channels ds = 1 for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): layers = [ ResBlock( ch, time_embed_dim, dropout, out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = mult * model_channels if ds in attention_resolutions: layers.append( AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=num_head_channels, use_new_attention_order=use_new_attention_order, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: out_ch = ch 
self.input_blocks.append( TimestepEmbedSequential( ResBlock( ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, ) if resblock_updown else Downsample( ch, conv_resample, dims=dims, out_channels=out_ch ) ) ) ch = out_ch input_block_chans.append(ch) ds *= 2 self._feature_size += ch self.middle_block = TimestepEmbedSequential( ResBlock( ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=num_head_channels, use_new_attention_order=use_new_attention_order, ), ResBlock( ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), ) self._feature_size += ch self.pool = pool if pool == 'adaptive': self.out = nn.Sequential( normalization(ch), nn.SiLU(), nn.AdaptiveAvgPool2d((1, 1)), zero_module(conv_nd(dims, ch, out_channels, 1)), nn.Flatten(), ) elif pool == 'attention': assert num_head_channels != -1 self.out = nn.Sequential( normalization(ch), nn.SiLU(), AttentionPool2d( (image_size // ds), ch, num_head_channels, out_channels ), ) elif pool == 'spatial': self.out = nn.Sequential( nn.Linear(self._feature_size, 2048), nn.ReLU(), nn.Linear(2048, self.out_channels), ) elif pool == 'spatial_v2': self.out = nn.Sequential( nn.Linear(self._feature_size, 2048), normalization(2048), nn.SiLU(), nn.Linear(2048, self.out_channels), ) else: raise NotImplementedError(f'Unexpected {pool} pooling') def convert_to_fp16(self): """ Convert the torso of the model to float16. """ self.input_blocks.apply(convert_module_to_f16) self.middle_block.apply(convert_module_to_f16) def convert_to_fp32(self): """ Convert the torso of the model to float32. 
""" self.input_blocks.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32) def forward(self, x, timesteps): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. :return: an [N x K] Tensor of outputs. """ emb = self.time_embed( timestep_embedding(timesteps, self.model_channels) ) results = [] h = x.type(self.dtype) for module in self.input_blocks: h = module(h, emb) if self.pool.startswith('spatial'): results.append(h.type(x.dtype).mean(dim=(2, 3))) h = self.middle_block(h, emb) if self.pool.startswith('spatial'): results.append(h.type(x.dtype).mean(dim=(2, 3))) h = th.cat(results, axis=-1) return self.out(h) else: h = h.type(x.dtype) return self.out(h) ================================================ FILE: src/stablediffusion/ldm/modules/diffusionmodules/util.py ================================================ # adopted from # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py # and # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py # and # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py # # thanks! 
import os import math import torch import torch.nn as nn import numpy as np from einops import repeat from src.stablediffusion.ldm.util import instantiate_from_config def make_beta_schedule( schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 ): if schedule == 'linear': betas = ( torch.linspace( linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64, ) ** 2 ) elif schedule == 'cosine': timesteps = ( torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s ) alphas = timesteps / (1 + cosine_s) * np.pi / 2 alphas = torch.cos(alphas).pow(2) alphas = alphas / alphas[0] betas = 1 - alphas[1:] / alphas[:-1] betas = np.clip(betas, a_min=0, a_max=0.999) elif schedule == 'sqrt_linear': betas = torch.linspace( linear_start, linear_end, n_timestep, dtype=torch.float64 ) elif schedule == 'sqrt': betas = ( torch.linspace( linear_start, linear_end, n_timestep, dtype=torch.float64 ) ** 0.5 ) else: raise ValueError(f"schedule '{schedule}' unknown.") return betas.numpy() def make_ddim_timesteps( ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True ): if ddim_discr_method == 'uniform': c = num_ddpm_timesteps // num_ddim_timesteps ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) elif ddim_discr_method == 'quad': ddim_timesteps = ( ( np.linspace( 0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps ) ) ** 2 ).astype(int) else: raise NotImplementedError( f'There is no ddim discretization method called "{ddim_discr_method}"' ) # assert ddim_timesteps.shape[0] == num_ddim_timesteps # add one to get the final alpha values right (the ones from first scale to data during sampling) # steps_out = ddim_timesteps + 1 steps_out = ddim_timesteps if verbose: print(f'Selected timesteps for ddim sampler: {steps_out}') return steps_out def make_ddim_sampling_parameters( alphacums, ddim_timesteps, eta, verbose=True ): # select alphas for computing the variance schedule alphas = alphacums[ddim_timesteps] 
alphas_prev = np.asarray( [alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist() ) # according the the formula provided in https://arxiv.org/abs/2010.02502 sigmas = eta * np.sqrt( (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) ) if verbose: print( f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}' ) print( f'For the chosen value of eta, which is {eta}, ' f'this results in the following sigma_t schedule for ddim sampler {sigmas}' ) return sigmas, alphas, alphas_prev def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): """ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of (1-beta) over time from t = [0,1]. :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t from 0 to 1 and produces the cumulative product of (1-beta) up to that part of the diffusion process. :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities. """ betas = [] for i in range(num_diffusion_timesteps): t1 = i / num_diffusion_timesteps t2 = (i + 1) / num_diffusion_timesteps betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) return np.array(betas) def extract_into_tensor(a, t, x_shape): b, *_ = t.shape out = a.gather(-1, t) return out.reshape(b, *((1,) * (len(x_shape) - 1))) def checkpoint(func, inputs, params, flag): """ Evaluate a function without caching intermediate activations, allowing for reduced memory at the expense of extra compute in the backward pass. :param func: the function to evaluate. :param inputs: the argument sequence to pass to `func`. :param params: a sequence of parameters `func` depends on but does not explicitly take as arguments. :param flag: if False, disable gradient checkpointing. 
""" if ( False ): # disabled checkpointing to allow requires_grad = False for main model args = tuple(inputs) + tuple(params) return CheckpointFunction.apply(func, len(inputs), *args) else: return func(*inputs) class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, length, *args): ctx.run_function = run_function ctx.input_tensors = list(args[:length]) ctx.input_params = list(args[length:]) with torch.no_grad(): output_tensors = ctx.run_function(*ctx.input_tensors) return output_tensors @staticmethod def backward(ctx, *output_grads): ctx.input_tensors = [ x.detach().requires_grad_(True) for x in ctx.input_tensors ] with torch.enable_grad(): # Fixes a bug where the first op in run_function modifies the # Tensor storage in place, which is not allowed for detach()'d # Tensors. shallow_copies = [x.view_as(x) for x in ctx.input_tensors] output_tensors = ctx.run_function(*shallow_copies) input_grads = torch.autograd.grad( output_tensors, ctx.input_tensors + ctx.input_params, output_grads, allow_unused=True, ) del ctx.input_tensors del ctx.input_params del output_tensors return (None, None) + input_grads def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): """ Create sinusoidal timestep embeddings. :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings. 
""" if not repeat_only: half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=timesteps.device) args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat( [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 ) else: embedding = repeat(timesteps, 'b -> b d', d=dim) return embedding def zero_module(module): """ Zero out the parameters of a module and return it. """ for p in module.parameters(): p.detach().zero_() return module def scale_module(module, scale): """ Scale the parameters of a module and return it. """ for p in module.parameters(): p.detach().mul_(scale) return module def mean_flat(tensor): """ Take the mean over all non-batch dimensions. """ return tensor.mean(dim=list(range(1, len(tensor.shape)))) def normalization(channels): """ Make a standard normalization layer. :param channels: number of input channels. :return: an nn.Module for normalization. """ return GroupNorm32(32, channels) # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. class SiLU(nn.Module): def forward(self, x): return x * torch.sigmoid(x) class GroupNorm32(nn.GroupNorm): def forward(self, x): return super().forward(x.float()).type(x.dtype) def conv_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. """ if dims == 1: return nn.Conv1d(*args, **kwargs) elif dims == 2: return nn.Conv2d(*args, **kwargs) elif dims == 3: return nn.Conv3d(*args, **kwargs) raise ValueError(f'unsupported dimensions: {dims}') def linear(*args, **kwargs): """ Create a linear module. """ return nn.Linear(*args, **kwargs) def avg_pool_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D average pooling module. 
""" if dims == 1: return nn.AvgPool1d(*args, **kwargs) elif dims == 2: return nn.AvgPool2d(*args, **kwargs) elif dims == 3: return nn.AvgPool3d(*args, **kwargs) raise ValueError(f'unsupported dimensions: {dims}') class HybridConditioner(nn.Module): def __init__(self, c_concat_config, c_crossattn_config): super().__init__() self.concat_conditioner = instantiate_from_config(c_concat_config) self.crossattn_conditioner = instantiate_from_config( c_crossattn_config ) def forward(self, c_concat, c_crossattn): c_concat = self.concat_conditioner(c_concat) c_crossattn = self.crossattn_conditioner(c_crossattn) return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} def noise_like(shape, device, repeat=False): repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( shape[0], *((1,) * (len(shape) - 1)) ) noise = lambda: torch.randn(shape, device=device) return repeat_noise() if repeat else noise() ================================================ FILE: src/stablediffusion/ldm/modules/distributions/__init__.py ================================================ ================================================ FILE: src/stablediffusion/ldm/modules/distributions/distributions.py ================================================ import torch import numpy as np class AbstractDistribution: def sample(self): raise NotImplementedError() def mode(self): raise NotImplementedError() class DiracDistribution(AbstractDistribution): def __init__(self, value): self.value = value def sample(self): return self.value def mode(self): return self.value class DiagonalGaussianDistribution(object): def __init__(self, parameters, deterministic=False): self.parameters = parameters self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) self.logvar = torch.clamp(self.logvar, -30.0, 20.0) self.deterministic = deterministic self.std = torch.exp(0.5 * self.logvar) self.var = torch.exp(self.logvar) if self.deterministic: self.var = self.std = torch.zeros_like(self.mean).to( 
device=self.parameters.device ) def sample(self): x = self.mean + self.std * torch.randn(self.mean.shape).to( device=self.parameters.device ) return x def kl(self, other=None): if self.deterministic: return torch.Tensor([0.0]) else: if other is None: return 0.5 * torch.sum( torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3], ) else: return 0.5 * torch.sum( torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar, dim=[1, 2, 3], ) def nll(self, sample, dims=[1, 2, 3]): if self.deterministic: return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi) return 0.5 * torch.sum( logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims, ) def mode(self): return self.mean def normal_kl(mean1, logvar1, mean2, logvar2): """ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 Compute the KL divergence between two gaussians. Shapes are automatically broadcasted, so batches can be compared to scalars, among other use cases. """ tensor = None for obj in (mean1, logvar1, mean2, logvar2): if isinstance(obj, torch.Tensor): tensor = obj break assert tensor is not None, 'at least one argument must be a Tensor' # Force variances to be Tensors. Broadcasting helps convert scalars to # Tensors, but it does not work for torch.exp(). 
logvar1, logvar2 = [ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2) ] return 0.5 * ( -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) ) ================================================ FILE: src/stablediffusion/ldm/modules/ema.py ================================================ import torch from torch import nn class LitEma(nn.Module): def __init__(self, model, decay=0.9999, use_num_upates=True): super().__init__() if decay < 0.0 or decay > 1.0: raise ValueError('Decay must be between 0 and 1') self.m_name2s_name = {} self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) self.register_buffer( 'num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int), ) for name, p in model.named_parameters(): if p.requires_grad: # remove as '.'-character is not allowed in buffers s_name = name.replace('.', '') self.m_name2s_name.update({name: s_name}) self.register_buffer(s_name, p.clone().detach().data) self.collected_params = [] def forward(self, model): decay = self.decay if self.num_updates >= 0: self.num_updates += 1 decay = min( self.decay, (1 + self.num_updates) / (10 + self.num_updates) ) one_minus_decay = 1.0 - decay with torch.no_grad(): m_param = dict(model.named_parameters()) shadow_params = dict(self.named_buffers()) for key in m_param: if m_param[key].requires_grad: sname = self.m_name2s_name[key] shadow_params[sname] = shadow_params[sname].type_as( m_param[key] ) shadow_params[sname].sub_( one_minus_decay * (shadow_params[sname] - m_param[key]) ) else: assert not key in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) shadow_params = dict(self.named_buffers()) for key in m_param: if m_param[key].requires_grad: m_param[key].data.copy_( shadow_params[self.m_name2s_name[key]].data ) else: assert not key in self.m_name2s_name def store(self, parameters): """ Save the 
current parameters for restoring later. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. """ self.collected_params = [param.clone() for param in parameters] def restore(self, parameters): """ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without affecting the original optimization process. Store the parameters before the `copy_to` method. After validation (or model saving), use this to restore the former parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. """ for c_param, param in zip(self.collected_params, parameters): param.data.copy_(c_param.data) ================================================ FILE: src/stablediffusion/ldm/modules/embedding_manager.py ================================================ from cmath import log import torch from torch import nn import sys from src.stablediffusion.ldm.data.personalized import per_img_token_list from transformers import CLIPTokenizer from functools import partial DEFAULT_PLACEHOLDER_TOKEN = ['*'] PROGRESSIVE_SCALE = 2000 def get_clip_token_for_string(tokenizer, string): batch_encoding = tokenizer( string, truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding='max_length', return_tensors='pt', ) tokens = batch_encoding['input_ids'] """ assert ( torch.count_nonzero(tokens - 49407) == 2 ), f"String '{string}' maps to more than a single token. Please use another string" """ return tokens[0, 1] def get_bert_token_for_string(tokenizer, string): token = tokenizer(string) # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. 
Please use another string" token = token[0, 1] return token def get_embedding_for_clip_token(embedder, token): return embedder(token.unsqueeze(0))[0, 0] class EmbeddingManager(nn.Module): def __init__( self, embedder, placeholder_strings=None, initializer_words=None, per_image_tokens=False, num_vectors_per_token=1, progressive_words=False, **kwargs, ): super().__init__() self.embedder = embedder self.string_to_token_dict = {} self.string_to_param_dict = nn.ParameterDict() self.initial_embeddings = ( nn.ParameterDict() ) # These should not be optimized self.progressive_words = progressive_words self.progressive_counter = 0 self.max_vectors_per_token = num_vectors_per_token if hasattr( embedder, 'tokenizer' ): # using Stable Diffusion's CLIP encoder self.is_clip = True get_token_for_string = partial( get_clip_token_for_string, embedder.tokenizer ) get_embedding_for_tkn = partial( get_embedding_for_clip_token, embedder.transformer.text_model.embeddings, ) token_dim = 1280 else: # using LDM's BERT encoder self.is_clip = False get_token_for_string = partial( get_bert_token_for_string, embedder.tknz_fn ) get_embedding_for_tkn = embedder.transformer.token_emb token_dim = 1280 if per_image_tokens: placeholder_strings.extend(per_img_token_list) for idx, placeholder_string in enumerate(placeholder_strings): token = get_token_for_string(placeholder_string) if initializer_words and idx < len(initializer_words): init_word_token = get_token_for_string(initializer_words[idx]) with torch.no_grad(): init_word_embedding = get_embedding_for_tkn( init_word_token.cpu() ) token_params = torch.nn.Parameter( init_word_embedding.unsqueeze(0).repeat( num_vectors_per_token, 1 ), requires_grad=True, ) self.initial_embeddings[ placeholder_string ] = torch.nn.Parameter( init_word_embedding.unsqueeze(0).repeat( num_vectors_per_token, 1 ), requires_grad=False, ) else: token_params = torch.nn.Parameter( torch.rand( size=(num_vectors_per_token, token_dim), requires_grad=True, ) ) 
self.string_to_token_dict[placeholder_string] = token self.string_to_param_dict[placeholder_string] = token_params def forward( self, tokenized_text, embedded_text, ): b, n, device = *tokenized_text.shape, tokenized_text.device for ( placeholder_string, placeholder_token, ) in self.string_to_token_dict.items(): placeholder_embedding = self.string_to_param_dict[ placeholder_string ].to(device) if ( self.max_vectors_per_token == 1 ): # If there's only one vector per token, we can do a simple replacement placeholder_idx = torch.where( tokenized_text == placeholder_token.to(device) ) embedded_text[placeholder_idx] = placeholder_embedding else: # otherwise, need to insert and keep track of changing indices if self.progressive_words: self.progressive_counter += 1 max_step_tokens = ( 1 + self.progressive_counter // PROGRESSIVE_SCALE ) else: max_step_tokens = self.max_vectors_per_token num_vectors_for_token = min( placeholder_embedding.shape[0], max_step_tokens ) placeholder_rows, placeholder_cols = torch.where( tokenized_text == placeholder_token.to(device) ) if placeholder_rows.nelement() == 0: continue sorted_cols, sort_idx = torch.sort( placeholder_cols, descending=True ) sorted_rows = placeholder_rows[sort_idx] for idx in range(len(sorted_rows)): row = sorted_rows[idx] col = sorted_cols[idx] new_token_row = torch.cat( [ tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to( device ), tokenized_text[row][col + 1 :], ], axis=0, )[:n] new_embed_row = torch.cat( [ embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1 :], ], axis=0, )[:n] embedded_text[row] = new_embed_row tokenized_text[row] = new_token_row return embedded_text def save(self, ckpt_path): torch.save( { 'string_to_token': self.string_to_token_dict, 'string_to_param': self.string_to_param_dict, }, ckpt_path, ) def load(self, ckpt_path, full=True): ckpt = torch.load(ckpt_path, map_location='cpu') # Handle .pt textual inversion files if 
'string_to_token' in ckpt and 'string_to_param' in ckpt: self.string_to_token_dict = ckpt["string_to_token"] self.string_to_param_dict = ckpt["string_to_param"] # Handle .bin textual inversion files from Huggingface Concepts # https://huggingface.co/sd-concepts-library else: for token_str in list(ckpt.keys()): token = get_clip_token_for_string(self.embedder.tokenizer, token_str) self.string_to_token_dict[token_str] = token ckpt[token_str] = torch.nn.Parameter(ckpt[token_str]) self.string_to_param_dict.update(ckpt) if not full: for key, value in self.string_to_param_dict.items(): self.string_to_param_dict[key] = torch.nn.Parameter(value.half()) print(f'Added terms: {", ".join(self.string_to_param_dict.keys())}') def get_embedding_norms_squared(self): all_params = torch.cat( list(self.string_to_param_dict.values()), axis=0 ) # num_placeholders x embedding_dim param_norm_squared = (all_params * all_params).sum( axis=-1 ) # num_placeholders return param_norm_squared def embedding_parameters(self): return self.string_to_param_dict.parameters() def embedding_to_coarse_loss(self): loss = 0.0 num_embeddings = len(self.initial_embeddings) for key in self.initial_embeddings: optimized = self.string_to_param_dict[key] coarse = self.initial_embeddings[key].clone().to(optimized.device) loss = ( loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings ) return loss ================================================ FILE: src/stablediffusion/ldm/modules/encoders/__init__.py ================================================ ================================================ FILE: src/stablediffusion/ldm/modules/encoders/modules.py ================================================ import torch import torch.nn as nn from functools import partial import clip from einops import rearrange, repeat from transformers import CLIPTokenizer, CLIPTextModel import kornia from src.stablediffusion.ldm.dream.devices import choose_torch_device from 
src.stablediffusion.ldm.modules.x_transformer import ( Encoder, TransformerWrapper, ) # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test def _expand_mask(mask, dtype, tgt_len=None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len expanded_mask = ( mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) ) inverted_mask = 1.0 - expanded_mask return inverted_mask.masked_fill( inverted_mask.to(torch.bool), torch.finfo(dtype).min ) def _build_causal_attention_mask(bsz, seq_len, dtype): # lazily create causal attention mask, with full attention between the vision tokens # pytorch uses additive attention mask; fill with -inf mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) mask.fill_(torch.tensor(torch.finfo(dtype).min)) mask.triu_(1) # zero out the lower diagonal mask = mask.unsqueeze(1) # expand mask return mask class AbstractEncoder(nn.Module): def __init__(self): super().__init__() def encode(self, *args, **kwargs): raise NotImplementedError class ClassEmbedder(nn.Module): def __init__(self, embed_dim, n_classes=1000, key='class'): super().__init__() self.key = key self.embedding = nn.Embedding(n_classes, embed_dim) def forward(self, batch, key=None): if key is None: key = self.key # this is for use in crossattn c = batch[key][:, None] c = self.embedding(c) return c class TransformerEmbedder(AbstractEncoder): """Some transformer encoder layers""" def __init__( self, n_embed, n_layer, vocab_size, max_seq_len=77, device=choose_torch_device(), ): super().__init__() self.device = device self.transformer = TransformerWrapper( num_tokens=vocab_size, max_seq_len=max_seq_len, attn_layers=Encoder(dim=n_embed, depth=n_layer), ) def forward(self, tokens): tokens = tokens.to(self.device) # meh z = self.transformer(tokens, return_embeddings=True) return z def encode(self, x): return self(x) class 
BERTTokenizer(AbstractEncoder): """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" def __init__( self, device=choose_torch_device(), vq_interface=True, max_length=77 ): super().__init__() from transformers import ( BertTokenizerFast, ) # TODO: add to reuquirements # Modified to allow to run on non-internet connected compute nodes. # Model needs to be loaded into cache from an internet-connected machine # by running: # from transformers import BertTokenizerFast # BertTokenizerFast.from_pretrained("bert-base-uncased") try: self.tokenizer = BertTokenizerFast.from_pretrained( 'bert-base-uncased', local_files_only=False ) except OSError: raise SystemExit( "* Couldn't load Bert tokenizer files. Try running scripts/preload_models.py from an internet-conected machine." ) self.device = device self.vq_interface = vq_interface self.max_length = max_length def forward(self, text): batch_encoding = self.tokenizer( text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding='max_length', return_tensors='pt', ) tokens = batch_encoding['input_ids'].to(self.device) return tokens @torch.no_grad() def encode(self, text): tokens = self(text) if not self.vq_interface: return tokens return None, None, [None, None, tokens] def decode(self, text): return text class BERTEmbedder(AbstractEncoder): """Uses the BERT tokenizr model and add some transformer encoder layers""" def __init__( self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, device=choose_torch_device(), use_tokenizer=True, embedding_dropout=0.0, ): super().__init__() self.use_tknz_fn = use_tokenizer if self.use_tknz_fn: self.tknz_fn = BERTTokenizer( vq_interface=False, max_length=max_seq_len ) self.device = device self.transformer = TransformerWrapper( num_tokens=vocab_size, max_seq_len=max_seq_len, attn_layers=Encoder(dim=n_embed, depth=n_layer), emb_dropout=embedding_dropout, ) def forward(self, text, embedding_manager=None): if 
self.use_tknz_fn: tokens = self.tknz_fn(text) # .to(self.device) else: tokens = text z = self.transformer( tokens, return_embeddings=True, embedding_manager=embedding_manager ) return z def encode(self, text, **kwargs): # output of length 77 return self(text, **kwargs) class SpatialRescaler(nn.Module): def __init__( self, n_stages=1, method='bilinear', multiplier=0.5, in_channels=3, out_channels=None, bias=False, ): super().__init__() self.n_stages = n_stages assert self.n_stages >= 0 assert method in [ 'nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area', ] self.multiplier = multiplier self.interpolator = partial( torch.nn.functional.interpolate, mode=method ) self.remap_output = out_channels is not None if self.remap_output: print( f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.' ) self.channel_mapper = nn.Conv2d( in_channels, out_channels, 1, bias=bias ) def forward(self, x): for stage in range(self.n_stages): x = self.interpolator(x, scale_factor=self.multiplier) if self.remap_output: x = self.channel_mapper(x) return x def encode(self, x): return self(x) class FrozenCLIPEmbedder(AbstractEncoder): """Uses the CLIP transformer encoder for text (from Hugging Face)""" def __init__( self, version='openai/clip-vit-large-patch14', device=choose_torch_device(), max_length=77, ): super().__init__() self.tokenizer = CLIPTokenizer.from_pretrained( version, local_files_only=False ) self.transformer = CLIPTextModel.from_pretrained( version, local_files_only=False ) self.device = device self.max_length = max_length self.freeze() def embedding_forward( self, input_ids=None, position_ids=None, inputs_embeds=None, embedding_manager=None, ) -> torch.Tensor: seq_length = ( input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] ) if position_ids is None: position_ids = self.position_ids[:, :seq_length] if inputs_embeds is None: inputs_embeds = self.token_embedding(input_ids) if embedding_manager is not None: 
inputs_embeds = embedding_manager(input_ids, inputs_embeds) position_embeddings = self.position_embedding(position_ids) embeddings = inputs_embeds + position_embeddings return embeddings self.transformer.text_model.embeddings.forward = ( embedding_forward.__get__(self.transformer.text_model.embeddings) ) def encoder_forward( self, inputs_embeds, attention_mask=None, causal_attention_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None hidden_states = inputs_embeds for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) layer_outputs = encoder_layer( hidden_states, attention_mask, causal_attention_mask, output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) return hidden_states self.transformer.text_model.encoder.forward = encoder_forward.__get__( self.transformer.text_model.encoder ) def text_encoder_forward( self, input_ids=None, attention_mask=None, position_ids=None, output_attentions=None, output_hidden_states=None, return_dict=None, embedding_manager=None, ): output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) if 
input_ids is None: raise ValueError('You have to specify either input_ids') input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) hidden_states = self.embeddings( input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager, ) bsz, seq_len = input_shape # CLIP's text model uses causal mask, prepare it here. # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 causal_attention_mask = _build_causal_attention_mask( bsz, seq_len, hidden_states.dtype ).to(hidden_states.device) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _expand_mask( attention_mask, hidden_states.dtype ) last_hidden_state = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, causal_attention_mask=causal_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) last_hidden_state = self.final_layer_norm(last_hidden_state) return last_hidden_state self.transformer.text_model.forward = text_encoder_forward.__get__( self.transformer.text_model ) def transformer_forward( self, input_ids=None, attention_mask=None, position_ids=None, output_attentions=None, output_hidden_states=None, return_dict=None, embedding_manager=None, ): return self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, embedding_manager=embedding_manager, ) self.transformer.forward = transformer_forward.__get__( self.transformer ) def freeze(self): self.transformer = self.transformer.eval() for param in self.parameters(): param.requires_grad = False def forward(self, text, **kwargs): batch_encoding = self.tokenizer( text, truncation=True, max_length=self.max_length, return_length=True, return_overflowing_tokens=False, padding='max_length', 
return_tensors='pt', ) tokens = batch_encoding['input_ids'].to(self.device) z = self.transformer(input_ids=tokens, **kwargs) return z def encode(self, text, **kwargs): return self(text, **kwargs) class FrozenCLIPTextEmbedder(nn.Module): """ Uses the CLIP transformer encoder for text. """ def __init__( self, version='ViT-L/14', device=choose_torch_device(), max_length=77, n_repeat=1, normalize=True, ): super().__init__() self.model, _ = clip.load(version, jit=False, device=device) self.device = device self.max_length = max_length self.n_repeat = n_repeat self.normalize = normalize def freeze(self): self.model = self.model.eval() for param in self.parameters(): param.requires_grad = False def forward(self, text): tokens = clip.tokenize(text).to(self.device) z = self.model.encode_text(tokens) if self.normalize: z = z / torch.linalg.norm(z, dim=1, keepdim=True) return z def encode(self, text): z = self(text) if z.ndim == 2: z = z[:, None, :] z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) return z class FrozenClipImageEmbedder(nn.Module): """ Uses the CLIP image encoder. 
""" def __init__( self, model, jit=False, device=choose_torch_device(), antialias=False, ): super().__init__() self.model, _ = clip.load(name=model, device=device, jit=jit) self.antialias = antialias self.register_buffer( 'mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False, ) self.register_buffer( 'std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False, ) def preprocess(self, x): # normalize to [0,1] x = kornia.geometry.resize( x, (224, 224), interpolation='bicubic', align_corners=True, antialias=self.antialias, ) x = (x + 1.0) / 2.0 # renormalize according to clip x = kornia.enhance.normalize(x, self.mean, self.std) return x def forward(self, x): # x is assumed to be in range [-1,1] return self.model.encode_image(self.preprocess(x)) if __name__ == '__main__': from src.stablediffusion.ldm.util import count_params model = FrozenCLIPEmbedder() count_params(model, verbose=True) ================================================ FILE: src/stablediffusion/ldm/modules/image_degradation/__init__.py ================================================ from src.stablediffusion.ldm.modules.image_degradation.bsrgan import ( degradation_bsrgan_variant as degradation_fn_bsr, ) from src.stablediffusion.ldm.modules.image_degradation.bsrgan_light import ( degradation_bsrgan_variant as degradation_fn_bsr_light, ) ================================================ FILE: src/stablediffusion/ldm/modules/image_degradation/bsrgan.py ================================================ # -*- coding: utf-8 -*- """ # -------------------------------------------- # Super-Resolution # -------------------------------------------- # # Kai Zhang (cskaizhang@gmail.com) # https://github.com/cszn # From 2019/03--2021/08 # -------------------------------------------- """ import numpy as np import cv2 import torch from functools import partial import random from scipy import ndimage import scipy import scipy.stats as ss from scipy.interpolate import interp2d from 
scipy.linalg import orth import albumentations import ldm.modules.image_degradation.utils_image as util def modcrop_np(img, sf): """ Args: img: numpy image, WxH or WxHxC sf: scale factor Return: cropped image """ w, h = img.shape[:2] im = np.copy(img) return im[: w - w % sf, : h - h % sf, ...] """ # -------------------------------------------- # anisotropic Gaussian kernels # -------------------------------------------- """ def analytic_kernel(k): """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" k_size = k.shape[0] # Calculate the big kernels size big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) # Loop over the small kernel to fill the big one for r in range(k_size): for c in range(k_size): big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += ( k[r, c] * k ) # Crop the edges of the big kernel to ignore very small values and increase run time of SR crop = k_size // 2 cropped_big_k = big_k[crop:-crop, crop:-crop] # Normalize to 1 return cropped_big_k / cropped_big_k.sum() def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): """generate an anisotropic Gaussian kernel Args: ksize : e.g., 15, kernel size theta : [0, pi], rotation angle range l1 : [0.1,50], scaling of eigenvalues l2 : [0.1,l1], scaling of eigenvalues If l1 = l2, will get an isotropic Gaussian kernel. 
Returns: k : kernel """ v = np.dot( np.array( [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] ), np.array([1.0, 0.0]), ) V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) return k def gm_blur_kernel(mean, cov, size=15): center = size / 2.0 + 0.5 k = np.zeros([size, size]) for y in range(size): for x in range(size): cy = y - center + 1 cx = x - center + 1 k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) k = k / np.sum(k) return k def shift_pixel(x, sf, upper_left=True): """shift pixel for super-resolution with different scale factors Args: x: WxHxC or WxH sf: scale factor upper_left: shift direction """ h, w = x.shape[:2] shift = (sf - 1) * 0.5 xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) if upper_left: x1 = xv + shift y1 = yv + shift else: x1 = xv - shift y1 = yv - shift x1 = np.clip(x1, 0, w - 1) y1 = np.clip(y1, 0, h - 1) if x.ndim == 2: x = interp2d(xv, yv, x)(x1, y1) if x.ndim == 3: for i in range(x.shape[-1]): x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) return x def blur(x, k): """ x: image, NxcxHxW k: kernel, Nx1xhxw """ n, c = x.shape[:2] p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') k = k.repeat(1, c, 1, 1) k = k.view(-1, 1, k.shape[2], k.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3]) x = torch.nn.functional.conv2d( x, k, bias=None, stride=1, padding=0, groups=n * c ) x = x.view(n, c, x.shape[2], x.shape[3]) return x def gen_kernel( k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10.0, noise_level=0, ): """ " # modified version of https://github.com/assafshocher/BlindSR_dataset_generator # Kai Zhang # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var # max_var = 2.5 * sf """ # Set random eigen-vals (lambdas) and angle 
(theta) for COV matrix lambda_1 = min_var + np.random.rand() * (max_var - min_var) lambda_2 = min_var + np.random.rand() * (max_var - min_var) theta = np.random.rand() * np.pi # random theta noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 # Set COV matrix using Lambdas and Theta LAMBDA = np.diag([lambda_1, lambda_2]) Q = np.array( [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] ) SIGMA = Q @ LAMBDA @ Q.T INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] # Set expectation position (shifting kernel for aligned image) MU = k_size // 2 - 0.5 * ( scale_factor - 1 ) # - 0.5 * (scale_factor - k_size % 2) MU = MU[None, None, :, None] # Create meshgrid for Gaussian [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) Z = np.stack([X, Y], 2)[:, :, :, None] # Calcualte Gaussian for every pixel of the kernel ZZ = Z - MU ZZ_t = ZZ.transpose(0, 1, 3, 2) raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) # shift the kernel so it will be centered # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) # Normalize the kernel and return # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) kernel = raw_kernel / np.sum(raw_kernel) return kernel def fspecial_gaussian(hsize, sigma): hsize = [hsize, hsize] siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] std = sigma [x, y] = np.meshgrid( np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1) ) arg = -(x * x + y * y) / (2 * std * std) h = np.exp(arg) h[h < scipy.finfo(float).eps * h.max()] = 0 sumh = h.sum() if sumh != 0: h = h / sumh return h def fspecial_laplacian(alpha): alpha = max([0, min([alpha, 1])]) h1 = alpha / (alpha + 1) h2 = (1 - alpha) / (alpha + 1) h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] h = np.array(h) return h def fspecial(filter_type, *args, **kwargs): """ python code from: 
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py """ if filter_type == 'gaussian': return fspecial_gaussian(*args, **kwargs) if filter_type == 'laplacian': return fspecial_laplacian(*args, **kwargs) """ # -------------------------------------------- # degradation models # -------------------------------------------- """ def bicubic_degradation(x, sf=3): """ Args: x: HxWxC image, [0, 1] sf: down-scale factor Return: bicubicly downsampled LR image """ x = util.imresize_np(x, scale=1 / sf) return x def srmd_degradation(x, k, sf=3): """blur + bicubic downsampling Args: x: HxWxC image, [0, 1] k: hxw, double sf: down-scale factor Return: downsampled LR image Reference: @inproceedings{zhang2018learning, title={Learning a single convolutional super-resolution network for multiple degradations}, author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, pages={3262--3271}, year={2018} } """ x = ndimage.filters.convolve( x, np.expand_dims(k, axis=2), mode='wrap' ) # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x def dpsr_degradation(x, k, sf=3): """bicubic downsampling + blur Args: x: HxWxC image, [0, 1] k: hxw, double sf: down-scale factor Return: downsampled LR image Reference: @inproceedings{zhang2019deep, title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, pages={1671--1681}, year={2019} } """ x = bicubic_degradation(x, sf=sf) x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') return x def classical_degradation(x, k, sf=3): """blur + downsampling Args: x: HxWxC image, [0, 1]/[0, 255] k: hxw, double sf: down-scale factor Return: downsampled LR image """ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # x = 
filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) st = 0 return x[st::sf, st::sf, ...] def add_sharpening(img, weight=0.5, radius=50, threshold=10): """USM sharpening. borrowed from real-ESRGAN Input image: I; Blurry image: B. 1. K = I + weight * (I - B) 2. Mask = 1 if abs(I - B) > threshold, else: 0 3. Blur mask: 4. Out = Mask * K + (1 - Mask) * I Args: img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. weight (float): Sharp weight. Default: 1. radius (float): Kernel size of Gaussian blur. Default: 50. threshold (int): """ if radius % 2 == 0: radius += 1 blur = cv2.GaussianBlur(img, (radius, radius), 0) residual = img - blur mask = np.abs(residual) * 255 > threshold mask = mask.astype('float32') soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) K = img + weight * residual K = np.clip(K, 0, 1) return soft_mask * K + (1 - soft_mask) * img def add_blur(img, sf=4): wd2 = 4.0 + sf wd = 2.0 + 0.2 * sf if random.random() < 0.5: l1 = wd2 * random.random() l2 = wd2 * random.random() k = anisotropic_Gaussian( ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2, ) else: k = fspecial( 'gaussian', 2 * random.randint(2, 11) + 3, wd * random.random() ) img = ndimage.filters.convolve( img, np.expand_dims(k, axis=2), mode='mirror' ) return img def add_resize(img, sf=4): rnum = np.random.rand() if rnum > 0.8: # up sf1 = random.uniform(1, 2) elif rnum < 0.7: # down sf1 = random.uniform(0.5 / sf, 1) else: sf1 = 1.0 img = cv2.resize( img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]), ) img = np.clip(img, 0.0, 1.0) return img # def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): # noise_level = random.randint(noise_level1, noise_level2) # rnum = np.random.rand() # if rnum > 0.6: # add color Gaussian noise # img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) # elif rnum < 0.4: # add grayscale Gaussian noise # img += np.random.normal(0, noise_level / 255.0, 
(*img.shape[:2], 1)).astype(np.float32) # else: # add noise # L = noise_level2 / 255. # D = np.diag(np.random.rand(3)) # U = orth(np.random.rand(3, 3)) # conv = np.dot(np.dot(np.transpose(U), D), U) # img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) # img = np.clip(img, 0.0, 1.0) # return img def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() if rnum > 0.6: # add color Gaussian noise img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype( np.float32 ) elif rnum < 0.4: # add grayscale Gaussian noise img = img + np.random.normal( 0, noise_level / 255.0, (*img.shape[:2], 1) ).astype(np.float32) else: # add noise L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) img = img + np.random.multivariate_normal( [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] ).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_speckle_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) img = np.clip(img, 0.0, 1.0) rnum = random.random() if rnum > 0.6: img += img * np.random.normal( 0, noise_level / 255.0, img.shape ).astype(np.float32) elif rnum < 0.4: img += img * np.random.normal( 0, noise_level / 255.0, (*img.shape[:2], 1) ).astype(np.float32) else: L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) img += img * np.random.multivariate_normal( [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] ).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_Poisson_noise(img): img = np.clip((img * 255.0).round(), 0, 255) / 255.0 vals = 10 ** (2 * random.random() + 2.0) # [2, 4] if random.random() < 0.5: img = np.random.poisson(img * vals).astype(np.float32) / vals else: img_gray = np.dot(img[..., :3], 
[0.299, 0.587, 0.114]) img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 noise_gray = ( np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray ) img += noise_gray[:, :, np.newaxis] img = np.clip(img, 0.0, 1.0) return img def add_JPEG_noise(img): quality_factor = random.randint(30, 95) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) result, encimg = cv2.imencode( '.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor] ) img = cv2.imdecode(encimg, 1) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) return img def random_crop(lq, hq, sf=4, lq_patchsize=64): h, w = lq.shape[:2] rnd_h = random.randint(0, h - lq_patchsize) rnd_w = random.randint(0, w - lq_patchsize) lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) hq = hq[ rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, :, ] return lq, hq def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): """ This is the degradation model of BSRGAN from the paper "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" ---------- img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) sf: scale factor isp_model: camera ISP model Returns ------- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] """ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 sf_ori = sf h1, w1 = img.shape[:2] img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: raise ValueError(f'img size ({h1}X{w1}) is too small!') hq = img.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: img = cv2.resize( img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), interpolation=random.choice([1, 2, 3]), ) else: img = util.imresize_np(img, 1 / 2, True) img = np.clip(img, 0.0, 1.0) sf = 2 shuffle_order = random.sample(range(7), 7) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) if idx1 > idx2: # keep downsample3 last shuffle_order[idx1], shuffle_order[idx2] = ( shuffle_order[idx2], shuffle_order[idx1], ) for i in shuffle_order: if i == 0: img = add_blur(img, sf=sf) elif i == 1: img = add_blur(img, sf=sf) elif i == 2: a, b = img.shape[1], img.shape[0] # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) img = cv2.resize( img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]), ) else: k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = ( k_shifted / k_shifted.sum() ) # blur with shifted kernel img = ndimage.filters.convolve( img, np.expand_dims(k_shifted, axis=2), mode='mirror' ) img = img[0::sf, 0::sf, ...] 
# nearest downsampling img = np.clip(img, 0.0, 1.0) elif i == 3: # downsample3 img = cv2.resize( img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]), ) img = np.clip(img, 0.0, 1.0) elif i == 4: # add Gaussian noise img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) elif i == 5: # add JPEG noise if random.random() < jpeg_prob: img = add_JPEG_noise(img) elif i == 6: # add processed camera sensor noise if random.random() < isp_prob and isp_model is not None: with torch.no_grad(): img, hq = isp_model.forward(img.copy(), hq) # add final JPEG compression noise img = add_JPEG_noise(img) # random crop img, hq = random_crop(img, hq, sf_ori, lq_patchsize) return img, hq # todo no isp_model? def degradation_bsrgan_variant(image, sf=4, isp_model=None): """ This is the degradation model of BSRGAN from the paper "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" ---------- sf: scale factor isp_model: camera ISP model Returns ------- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] """ image = util.uint2single(image) isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 sf_ori = sf h1, w1 = image.shape[:2] image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = image.shape[:2] hq = image.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: image = cv2.resize( image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), interpolation=random.choice([1, 2, 3]), ) else: image = util.imresize_np(image, 1 / 2, True) image = np.clip(image, 0.0, 1.0) sf = 2 shuffle_order = random.sample(range(7), 7) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) if idx1 > idx2: # keep downsample3 last shuffle_order[idx1], shuffle_order[idx2] = ( shuffle_order[idx2], shuffle_order[idx1], ) for i in shuffle_order: if i == 0: image = add_blur(image, sf=sf) elif i == 1: image = add_blur(image, sf=sf) elif i == 2: a, b = image.shape[1], image.shape[0] # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) image = cv2.resize( image, ( int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0]), ), interpolation=random.choice([1, 2, 3]), ) else: k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = ( k_shifted / k_shifted.sum() ) # blur with shifted kernel image = ndimage.filters.convolve( image, np.expand_dims(k_shifted, axis=2), mode='mirror' ) image = image[0::sf, 0::sf, ...] 
# nearest downsampling image = np.clip(image, 0.0, 1.0) elif i == 3: # downsample3 image = cv2.resize( image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]), ) image = np.clip(image, 0.0, 1.0) elif i == 4: # add Gaussian noise image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) elif i == 5: # add JPEG noise if random.random() < jpeg_prob: image = add_JPEG_noise(image) # elif i == 6: # # add processed camera sensor noise # if random.random() < isp_prob and isp_model is not None: # with torch.no_grad(): # img, hq = isp_model.forward(img.copy(), hq) # add final JPEG compression noise image = add_JPEG_noise(image) image = util.single2uint(image) example = {'image': image} return example # TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... def degradation_bsrgan_plus( img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None, ): """ This is an extended degradation model by combining the degradation models of BSRGAN and Real-ESRGAN ---------- img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) sf: scale factor use_shuffle: the degradation shuffle use_sharp: sharpening the img Returns ------- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] """ h1, w1 = img.shape[:2] img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: raise ValueError(f'img size ({h1}X{w1}) is too small!') if use_sharp: img = add_sharpening(img) hq = img.copy() if random.random() < shuffle_prob: shuffle_order = random.sample(range(13), 13) else: shuffle_order = list(range(13)) # local shuffle for noise, JPEG is always the last one shuffle_order[2:6] = random.sample( shuffle_order[2:6], len(range(2, 6)) ) shuffle_order[9:13] = random.sample( shuffle_order[9:13], len(range(9, 13)) ) poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 for i in shuffle_order: if i == 0: img = add_blur(img, sf=sf) elif i == 1: img = add_resize(img, sf=sf) elif i == 2: img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) elif i == 3: if random.random() < poisson_prob: img = add_Poisson_noise(img) elif i == 4: if random.random() < speckle_prob: img = add_speckle_noise(img) elif i == 5: if random.random() < isp_prob and isp_model is not None: with torch.no_grad(): img, hq = isp_model.forward(img.copy(), hq) elif i == 6: img = add_JPEG_noise(img) elif i == 7: img = add_blur(img, sf=sf) elif i == 8: img = add_resize(img, sf=sf) elif i == 9: img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) elif i == 10: if random.random() < poisson_prob: img = add_Poisson_noise(img) elif i == 11: if random.random() < speckle_prob: img = add_speckle_noise(img) elif i == 12: if random.random() < isp_prob and isp_model is not None: with torch.no_grad(): img, hq = isp_model.forward(img.copy(), hq) else: print('check the shuffle!') # resize to desired size img = cv2.resize( img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), interpolation=random.choice([1, 2, 3]), ) # add final JPEG compression noise img = add_JPEG_noise(img) # random crop img, hq = random_crop(img, hq, sf, lq_patchsize) return img, hq if __name__ == '__main__': print('hey') img = util.imread_uint('utils/test.png', 3) print(img) img = util.uint2single(img) print(img) img = 
img[:448, :448] h = img.shape[0] // 4 print('resizing to', h) sf = 4 deg_fn = partial(degradation_bsrgan_variant, sf=sf) for i in range(20): print(i) img_lq = deg_fn(img) print(img_lq) img_lq_bicubic = albumentations.SmallestMaxSize( max_size=h, interpolation=cv2.INTER_CUBIC )(image=img)['image'] print(img_lq.shape) print('bicubic', img_lq_bicubic.shape) print(img_hq.shape) lq_nearest = cv2.resize( util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0, ) lq_bicubic_nearest = cv2.resize( util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0, ) img_concat = np.concatenate( [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1 ) util.imsave(img_concat, str(i) + '.png') ================================================ FILE: src/stablediffusion/ldm/modules/image_degradation/bsrgan_light.py ================================================ # -*- coding: utf-8 -*- import numpy as np import cv2 import torch from functools import partial import random from scipy import ndimage import scipy import scipy.stats as ss from scipy.interpolate import interp2d from scipy.linalg import orth import albumentations import ldm.modules.image_degradation.utils_image as util """ # -------------------------------------------- # Super-Resolution # -------------------------------------------- # # Kai Zhang (cskaizhang@gmail.com) # https://github.com/cszn # From 2019/03--2021/08 # -------------------------------------------- """ def modcrop_np(img, sf): """ Args: img: numpy image, WxH or WxHxC sf: scale factor Return: cropped image """ w, h = img.shape[:2] im = np.copy(img) return im[: w - w % sf, : h - h % sf, ...] 
""" # -------------------------------------------- # anisotropic Gaussian kernels # -------------------------------------------- """ def analytic_kernel(k): """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" k_size = k.shape[0] # Calculate the big kernels size big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) # Loop over the small kernel to fill the big one for r in range(k_size): for c in range(k_size): big_k[2 * r : 2 * r + k_size, 2 * c : 2 * c + k_size] += ( k[r, c] * k ) # Crop the edges of the big kernel to ignore very small values and increase run time of SR crop = k_size // 2 cropped_big_k = big_k[crop:-crop, crop:-crop] # Normalize to 1 return cropped_big_k / cropped_big_k.sum() def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): """generate an anisotropic Gaussian kernel Args: ksize : e.g., 15, kernel size theta : [0, pi], rotation angle range l1 : [0.1,50], scaling of eigenvalues l2 : [0.1,l1], scaling of eigenvalues If l1 = l2, will get an isotropic Gaussian kernel. 
Returns: k : kernel """ v = np.dot( np.array( [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] ), np.array([1.0, 0.0]), ) V = np.array([[v[0], v[1]], [v[1], -v[0]]]) D = np.array([[l1, 0], [0, l2]]) Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) return k def gm_blur_kernel(mean, cov, size=15): center = size / 2.0 + 0.5 k = np.zeros([size, size]) for y in range(size): for x in range(size): cy = y - center + 1 cx = x - center + 1 k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) k = k / np.sum(k) return k def shift_pixel(x, sf, upper_left=True): """shift pixel for super-resolution with different scale factors Args: x: WxHxC or WxH sf: scale factor upper_left: shift direction """ h, w = x.shape[:2] shift = (sf - 1) * 0.5 xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) if upper_left: x1 = xv + shift y1 = yv + shift else: x1 = xv - shift y1 = yv - shift x1 = np.clip(x1, 0, w - 1) y1 = np.clip(y1, 0, h - 1) if x.ndim == 2: x = interp2d(xv, yv, x)(x1, y1) if x.ndim == 3: for i in range(x.shape[-1]): x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) return x def blur(x, k): """ x: image, NxcxHxW k: kernel, Nx1xhxw """ n, c = x.shape[:2] p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') k = k.repeat(1, c, 1, 1) k = k.view(-1, 1, k.shape[2], k.shape[3]) x = x.view(1, -1, x.shape[2], x.shape[3]) x = torch.nn.functional.conv2d( x, k, bias=None, stride=1, padding=0, groups=n * c ) x = x.view(n, c, x.shape[2], x.shape[3]) return x def gen_kernel( k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10.0, noise_level=0, ): """ " # modified version of https://github.com/assafshocher/BlindSR_dataset_generator # Kai Zhang # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var # max_var = 2.5 * sf """ # Set random eigen-vals (lambdas) and angle 
(theta) for COV matrix lambda_1 = min_var + np.random.rand() * (max_var - min_var) lambda_2 = min_var + np.random.rand() * (max_var - min_var) theta = np.random.rand() * np.pi # random theta noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 # Set COV matrix using Lambdas and Theta LAMBDA = np.diag([lambda_1, lambda_2]) Q = np.array( [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] ) SIGMA = Q @ LAMBDA @ Q.T INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] # Set expectation position (shifting kernel for aligned image) MU = k_size // 2 - 0.5 * ( scale_factor - 1 ) # - 0.5 * (scale_factor - k_size % 2) MU = MU[None, None, :, None] # Create meshgrid for Gaussian [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) Z = np.stack([X, Y], 2)[:, :, :, None] # Calcualte Gaussian for every pixel of the kernel ZZ = Z - MU ZZ_t = ZZ.transpose(0, 1, 3, 2) raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) # shift the kernel so it will be centered # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) # Normalize the kernel and return # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) kernel = raw_kernel / np.sum(raw_kernel) return kernel def fspecial_gaussian(hsize, sigma): hsize = [hsize, hsize] siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] std = sigma [x, y] = np.meshgrid( np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1) ) arg = -(x * x + y * y) / (2 * std * std) h = np.exp(arg) h[h < scipy.finfo(float).eps * h.max()] = 0 sumh = h.sum() if sumh != 0: h = h / sumh return h def fspecial_laplacian(alpha): alpha = max([0, min([alpha, 1])]) h1 = alpha / (alpha + 1) h2 = (1 - alpha) / (alpha + 1) h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] h = np.array(h) return h def fspecial(filter_type, *args, **kwargs): """ python code from: 
https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py """ if filter_type == 'gaussian': return fspecial_gaussian(*args, **kwargs) if filter_type == 'laplacian': return fspecial_laplacian(*args, **kwargs) """ # -------------------------------------------- # degradation models # -------------------------------------------- """ def bicubic_degradation(x, sf=3): """ Args: x: HxWxC image, [0, 1] sf: down-scale factor Return: bicubicly downsampled LR image """ x = util.imresize_np(x, scale=1 / sf) return x def srmd_degradation(x, k, sf=3): """blur + bicubic downsampling Args: x: HxWxC image, [0, 1] k: hxw, double sf: down-scale factor Return: downsampled LR image Reference: @inproceedings{zhang2018learning, title={Learning a single convolutional super-resolution network for multiple degradations}, author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, pages={3262--3271}, year={2018} } """ x = ndimage.filters.convolve( x, np.expand_dims(k, axis=2), mode='wrap' ) # 'nearest' | 'mirror' x = bicubic_degradation(x, sf=sf) return x def dpsr_degradation(x, k, sf=3): """bicubic downsampling + blur Args: x: HxWxC image, [0, 1] k: hxw, double sf: down-scale factor Return: downsampled LR image Reference: @inproceedings{zhang2019deep, title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, pages={1671--1681}, year={2019} } """ x = bicubic_degradation(x, sf=sf) x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') return x def classical_degradation(x, k, sf=3): """blur + downsampling Args: x: HxWxC image, [0, 1]/[0, 255] k: hxw, double sf: down-scale factor Return: downsampled LR image """ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # x = 
filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) st = 0 return x[st::sf, st::sf, ...] def add_sharpening(img, weight=0.5, radius=50, threshold=10): """USM sharpening. borrowed from real-ESRGAN Input image: I; Blurry image: B. 1. K = I + weight * (I - B) 2. Mask = 1 if abs(I - B) > threshold, else: 0 3. Blur mask: 4. Out = Mask * K + (1 - Mask) * I Args: img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. weight (float): Sharp weight. Default: 1. radius (float): Kernel size of Gaussian blur. Default: 50. threshold (int): """ if radius % 2 == 0: radius += 1 blur = cv2.GaussianBlur(img, (radius, radius), 0) residual = img - blur mask = np.abs(residual) * 255 > threshold mask = mask.astype('float32') soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) K = img + weight * residual K = np.clip(K, 0, 1) return soft_mask * K + (1 - soft_mask) * img def add_blur(img, sf=4): wd2 = 4.0 + sf wd = 2.0 + 0.2 * sf wd2 = wd2 / 4 wd = wd / 4 if random.random() < 0.5: l1 = wd2 * random.random() l2 = wd2 * random.random() k = anisotropic_Gaussian( ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2, ) else: k = fspecial( 'gaussian', random.randint(2, 4) + 3, wd * random.random() ) img = ndimage.filters.convolve( img, np.expand_dims(k, axis=2), mode='mirror' ) return img def add_resize(img, sf=4): rnum = np.random.rand() if rnum > 0.8: # up sf1 = random.uniform(1, 2) elif rnum < 0.7: # down sf1 = random.uniform(0.5 / sf, 1) else: sf1 = 1.0 img = cv2.resize( img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]), ) img = np.clip(img, 0.0, 1.0) return img # def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): # noise_level = random.randint(noise_level1, noise_level2) # rnum = np.random.rand() # if rnum > 0.6: # add color Gaussian noise # img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) # elif rnum < 0.4: # add grayscale Gaussian noise # img += np.random.normal(0, 
noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) # else: # add noise # L = noise_level2 / 255. # D = np.diag(np.random.rand(3)) # U = orth(np.random.rand(3, 3)) # conv = np.dot(np.dot(np.transpose(U), D), U) # img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) # img = np.clip(img, 0.0, 1.0) # return img def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) rnum = np.random.rand() if rnum > 0.6: # add color Gaussian noise img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype( np.float32 ) elif rnum < 0.4: # add grayscale Gaussian noise img = img + np.random.normal( 0, noise_level / 255.0, (*img.shape[:2], 1) ).astype(np.float32) else: # add noise L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) img = img + np.random.multivariate_normal( [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] ).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_speckle_noise(img, noise_level1=2, noise_level2=25): noise_level = random.randint(noise_level1, noise_level2) img = np.clip(img, 0.0, 1.0) rnum = random.random() if rnum > 0.6: img += img * np.random.normal( 0, noise_level / 255.0, img.shape ).astype(np.float32) elif rnum < 0.4: img += img * np.random.normal( 0, noise_level / 255.0, (*img.shape[:2], 1) ).astype(np.float32) else: L = noise_level2 / 255.0 D = np.diag(np.random.rand(3)) U = orth(np.random.rand(3, 3)) conv = np.dot(np.dot(np.transpose(U), D), U) img += img * np.random.multivariate_normal( [0, 0, 0], np.abs(L**2 * conv), img.shape[:2] ).astype(np.float32) img = np.clip(img, 0.0, 1.0) return img def add_Poisson_noise(img): img = np.clip((img * 255.0).round(), 0, 255) / 255.0 vals = 10 ** (2 * random.random() + 2.0) # [2, 4] if random.random() < 0.5: img = np.random.poisson(img * vals).astype(np.float32) / vals else: img_gray = 
np.dot(img[..., :3], [0.299, 0.587, 0.114]) img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.0 noise_gray = ( np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray ) img += noise_gray[:, :, np.newaxis] img = np.clip(img, 0.0, 1.0) return img def add_JPEG_noise(img): quality_factor = random.randint(80, 95) img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) result, encimg = cv2.imencode( '.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor] ) img = cv2.imdecode(encimg, 1) img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) return img def random_crop(lq, hq, sf=4, lq_patchsize=64): h, w = lq.shape[:2] rnd_h = random.randint(0, h - lq_patchsize) rnd_w = random.randint(0, w - lq_patchsize) lq = lq[rnd_h : rnd_h + lq_patchsize, rnd_w : rnd_w + lq_patchsize, :] rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) hq = hq[ rnd_h_H : rnd_h_H + lq_patchsize * sf, rnd_w_H : rnd_w_H + lq_patchsize * sf, :, ] return lq, hq def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): """ This is the degradation model of BSRGAN from the paper "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" ---------- img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) sf: scale factor isp_model: camera ISP model Returns ------- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] """ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 sf_ori = sf h1, w1 = img.shape[:2] img = img.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = img.shape[:2] if h < lq_patchsize * sf or w < lq_patchsize * sf: raise ValueError(f'img size ({h1}X{w1}) is too small!') hq = img.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: img = cv2.resize( img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), interpolation=random.choice([1, 2, 3]), ) else: img = util.imresize_np(img, 1 / 2, True) img = np.clip(img, 0.0, 1.0) sf = 2 shuffle_order = random.sample(range(7), 7) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) if idx1 > idx2: # keep downsample3 last shuffle_order[idx1], shuffle_order[idx2] = ( shuffle_order[idx2], shuffle_order[idx1], ) for i in shuffle_order: if i == 0: img = add_blur(img, sf=sf) elif i == 1: img = add_blur(img, sf=sf) elif i == 2: a, b = img.shape[1], img.shape[0] # downsample2 if random.random() < 0.75: sf1 = random.uniform(1, 2 * sf) img = cv2.resize( img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]), ) else: k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = ( k_shifted / k_shifted.sum() ) # blur with shifted kernel img = ndimage.filters.convolve( img, np.expand_dims(k_shifted, axis=2), mode='mirror' ) img = img[0::sf, 0::sf, ...] 
# nearest downsampling img = np.clip(img, 0.0, 1.0) elif i == 3: # downsample3 img = cv2.resize( img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]), ) img = np.clip(img, 0.0, 1.0) elif i == 4: # add Gaussian noise img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) elif i == 5: # add JPEG noise if random.random() < jpeg_prob: img = add_JPEG_noise(img) elif i == 6: # add processed camera sensor noise if random.random() < isp_prob and isp_model is not None: with torch.no_grad(): img, hq = isp_model.forward(img.copy(), hq) # add final JPEG compression noise img = add_JPEG_noise(img) # random crop img, hq = random_crop(img, hq, sf_ori, lq_patchsize) return img, hq # todo no isp_model? def degradation_bsrgan_variant(image, sf=4, isp_model=None): """ This is the degradation model of BSRGAN from the paper "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" ---------- sf: scale factor isp_model: camera ISP model Returns ------- img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] """ image = util.uint2single(image) isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 sf_ori = sf h1, w1 = image.shape[:2] image = image.copy()[: w1 - w1 % sf, : h1 - h1 % sf, ...] 
# mod crop h, w = image.shape[:2] hq = image.copy() if sf == 4 and random.random() < scale2_prob: # downsample1 if np.random.rand() < 0.5: image = cv2.resize( image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), interpolation=random.choice([1, 2, 3]), ) else: image = util.imresize_np(image, 1 / 2, True) image = np.clip(image, 0.0, 1.0) sf = 2 shuffle_order = random.sample(range(7), 7) idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) if idx1 > idx2: # keep downsample3 last shuffle_order[idx1], shuffle_order[idx2] = ( shuffle_order[idx2], shuffle_order[idx1], ) for i in shuffle_order: if i == 0: image = add_blur(image, sf=sf) # elif i == 1: # image = add_blur(image, sf=sf) if i == 0: pass elif i == 2: a, b = image.shape[1], image.shape[0] # downsample2 if random.random() < 0.8: sf1 = random.uniform(1, 2 * sf) image = cv2.resize( image, ( int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0]), ), interpolation=random.choice([1, 2, 3]), ) else: k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) k_shifted = shift_pixel(k, sf) k_shifted = ( k_shifted / k_shifted.sum() ) # blur with shifted kernel image = ndimage.filters.convolve( image, np.expand_dims(k_shifted, axis=2), mode='mirror' ) image = image[0::sf, 0::sf, ...] 
# nearest downsampling image = np.clip(image, 0.0, 1.0) elif i == 3: # downsample3 image = cv2.resize( image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]), ) image = np.clip(image, 0.0, 1.0) elif i == 4: # add Gaussian noise image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) elif i == 5: # add JPEG noise if random.random() < jpeg_prob: image = add_JPEG_noise(image) # # elif i == 6: # # add processed camera sensor noise # if random.random() < isp_prob and isp_model is not None: # with torch.no_grad(): # img, hq = isp_model.forward(img.copy(), hq) # add final JPEG compression noise image = add_JPEG_noise(image) image = util.single2uint(image) example = {'image': image} return example if __name__ == '__main__': print('hey') img = util.imread_uint('utils/test.png', 3) img = img[:448, :448] h = img.shape[0] // 4 print('resizing to', h) sf = 4 deg_fn = partial(degradation_bsrgan_variant, sf=sf) for i in range(20): print(i) img_hq = img img_lq = deg_fn(img)['image'] img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) print(img_lq) img_lq_bicubic = albumentations.SmallestMaxSize( max_size=h, interpolation=cv2.INTER_CUBIC )(image=img_hq)['image'] print(img_lq.shape) print('bicubic', img_lq_bicubic.shape) print(img_hq.shape) lq_nearest = cv2.resize( util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0, ) lq_bicubic_nearest = cv2.resize( util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), interpolation=0, ) img_concat = np.concatenate( [lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1 ) util.imsave(img_concat, str(i) + '.png') ================================================ FILE: src/stablediffusion/ldm/modules/image_degradation/utils_image.py ================================================ import os import math import random import numpy as np import torch import cv2 from torchvision.utils import make_grid from 
datetime import datetime # import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' """ # -------------------------------------------- # Kai Zhang (github: https://github.com/cszn) # 03/Mar/2019 # -------------------------------------------- # https://github.com/twhui/SRGAN-pyTorch # https://github.com/xinntao/BasicSR # -------------------------------------------- """ IMG_EXTENSIONS = [ '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif', ] def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) def get_timestamp(): return datetime.now().strftime('%y%m%d-%H%M%S') def imshow(x, title=None, cbar=False, figsize=None): plt.figure(figsize=figsize) plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') if title: plt.title(title) if cbar: plt.colorbar() plt.show() def surf(Z, cmap='rainbow', figsize=None): plt.figure(figsize=figsize) ax3 = plt.axes(projection='3d') w, h = Z.shape[:2] xx = np.arange(0, w, 1) yy = np.arange(0, h, 1) X, Y = np.meshgrid(xx, yy) ax3.plot_surface(X, Y, Z, cmap=cmap) # ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) plt.show() """ # -------------------------------------------- # get image pathes # -------------------------------------------- """ def get_image_paths(dataroot): paths = None # return None if dataroot is None if dataroot is not None: paths = sorted(_get_paths_from_images(dataroot)) return paths def _get_paths_from_images(path): assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) images = [] for dirpath, _, fnames in sorted(os.walk(path)): for fname in sorted(fnames): if is_image_file(fname): img_path = os.path.join(dirpath, fname) images.append(img_path) assert images, '{:s} has no valid image file'.format(path) return images """ # -------------------------------------------- # split large images into small images # 
-------------------------------------------- """ def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): w, h = img.shape[:2] patches = [] if w > p_max and h > p_max: w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=np.int)) h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=np.int)) w1.append(w - p_size) h1.append(h - p_size) # print(w1) # print(h1) for i in w1: for j in h1: patches.append(img[i : i + p_size, j : j + p_size, :]) else: patches.append(img) return patches def imssave(imgs, img_path): """ imgs: list, N images of size WxHxC """ img_name, ext = os.path.splitext(os.path.basename(img_path)) for i, img in enumerate(imgs): if img.ndim == 3: img = img[:, :, [2, 1, 0]] new_path = os.path.join( os.path.dirname(img_path), img_name + str('_s{:04d}'.format(i)) + '.png', ) cv2.imwrite(new_path, img) def split_imageset( original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000, ): """ split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) will be splitted. Args: original_dataroot: taget_dataroot: p_size: size of small images p_overlap: patch size in training is a good choice p_max: images with smaller size than (p_max)x(p_max) keep unchanged. 
""" paths = get_image_paths(original_dataroot) for img_path in paths: # img_name, ext = os.path.splitext(os.path.basename(img_path)) img = imread_uint(img_path, n_channels=n_channels) patches = patches_from_image(img, p_size, p_overlap, p_max) imssave( patches, os.path.join(taget_dataroot, os.path.basename(img_path)) ) # if original_dataroot == taget_dataroot: # del img_path """ # -------------------------------------------- # makedir # -------------------------------------------- """ def mkdir(path): if not os.path.exists(path): os.makedirs(path) def mkdirs(paths): if isinstance(paths, str): mkdir(paths) else: for path in paths: mkdir(path) def mkdir_and_rename(path): if os.path.exists(path): new_name = path + '_archived_' + get_timestamp() print('Path already exists. Rename it to [{:s}]'.format(new_name)) os.rename(path, new_name) os.makedirs(path) """ # -------------------------------------------- # read image from path # opencv is fast, but read BGR numpy image # -------------------------------------------- """ # -------------------------------------------- # get uint8 image of size HxWxn_channles (RGB) # -------------------------------------------- def imread_uint(path, n_channels=3): # input: path # output: HxWx3(RGB or GGG), or HxWx1 (G) if n_channels == 1: img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE img = np.expand_dims(img, axis=2) # HxWx1 elif n_channels == 3: img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G if img.ndim == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG else: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB return img # -------------------------------------------- # matlab's imwrite # -------------------------------------------- def imsave(img, img_path): img = np.squeeze(img) if img.ndim == 3: img = img[:, :, [2, 1, 0]] cv2.imwrite(img_path, img) def imwrite(img, img_path): img = np.squeeze(img) if img.ndim == 3: img = img[:, :, [2, 1, 0]] cv2.imwrite(img_path, img) # -------------------------------------------- # get 
# single image of size HxWxn_channels (BGR)
# --------------------------------------------
def read_img(path):
    """Read an image with OpenCV as float32 in [0, 1], HWC, BGR channel order.

    Grayscale files get a trailing channel axis (HxWx1). Only the first
    three channels of files with more than three are kept.

    Raises:
        OSError: if OpenCV cannot read ``path``.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
    if img is None:
        raise OSError(f'cannot read image: {path}')
    img = img.astype(np.float32) / 255.0
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img


"""
# --------------------------------------------
# image format conversion
# --------------------------------------------
# numpy(single) <---> numpy(unit)
# numpy(single) <---> tensor
# numpy(unit) <---> tensor
# --------------------------------------------
"""


# --------------------------------------------
# numpy(single) [0, 1] <---> numpy(unit)
# --------------------------------------------


def uint2single(img):
    """Scale an 8-bit array to float32 in [0, 1]."""
    return np.float32(img / 255.0)


def single2uint(img):
    """Clip a float array to [0, 1] and convert it to rounded uint8."""
    return np.uint8((img.clip(0, 1) * 255.0).round())


def uint162single(img):
    """Scale a 16-bit array to float32 in [0, 1]."""
    return np.float32(img / 65535.0)


def single2uint16(img):
    """Clip a float array to [0, 1] and convert it to rounded uint16."""
    return np.uint16((img.clip(0, 1) * 65535.0).round())


# --------------------------------------------
# numpy(unit) (HxWxC or HxW) <---> tensor
# --------------------------------------------


# convert uint to 4-dimensional torch tensor
def uint2tensor4(img):
    """HxWxC (or HxW) uint array -> 1xCxHxW float tensor in [0, 1]."""
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return (
        torch.from_numpy(np.ascontiguousarray(img))
        .permute(2, 0, 1)
        .float()
        .div(255.0)
        .unsqueeze(0)
    )


# convert uint to 3-dimensional torch tensor
def uint2tensor3(img):
    """HxWxC (or HxW) uint array -> CxHxW float tensor in [0, 1]."""
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return (
        torch.from_numpy(np.ascontiguousarray(img))
        .permute(2, 0, 1)
        .float()
        .div(255.0)
    )


# convert 2/3/4-dimensional torch tensor to uint
def tensor2uint(img):
    """Tensor in [0, 1] -> rounded uint8 array (HxWxC when 3-D after squeeze).

    The input tensor is not modified.
    """
    # Out-of-place clamp: squeeze() returns a view and float() returns self
    # for float32 input, so clamp_() would overwrite the caller's tensor.
    img = img.data.squeeze().float().clamp(0, 1).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return np.uint8((img * 255.0).round())


# --------------------------------------------
# numpy(single) (HxWxC) <---> tensor
# --------------------------------------------
# convert single (HxWxC) to 3-dimensional torch tensor
def single2tensor3(img):
    """HxWxC array -> CxHxW float tensor."""
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()


# convert single (HxWxC) to 4-dimensional torch tensor
def single2tensor4(img):
    """HxWxC array -> 1xCxHxW float tensor."""
    return (
        torch.from_numpy(np.ascontiguousarray(img))
        .permute(2, 0, 1)
        .float()
        .unsqueeze(0)
    )


# convert torch tensor to single
def tensor2single(img):
    """Tensor -> squeezed float32 numpy array (HxWxC when 3-D)."""
    img = img.data.squeeze().float().cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))

    return img


# convert torch tensor to single
def tensor2single3(img):
    """Tensor -> squeezed float32 numpy array, always HxWxC (2-D gets a channel axis)."""
    img = img.data.squeeze().float().cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    elif img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return img


def single2tensor5(img):
    """4-D array with axes (0, 1, 2, 3) -> 1 x (2, 0, 1, 3)-permuted float tensor."""
    return (
        torch.from_numpy(np.ascontiguousarray(img))
        .permute(2, 0, 1, 3)
        .float()
        .unsqueeze(0)
    )


def single32tensor5(img):
    """Array -> float tensor with two leading singleton axes."""
    return (
        torch.from_numpy(np.ascontiguousarray(img))
        .float()
        .unsqueeze(0)
        .unsqueeze(0)
    )


def single42tensor4(img):
    """4-D array -> (2, 0, 1, 3)-permuted float tensor."""
    return (
        torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
    )


# from skimage.io import imread, imsave
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    """
    Converts a torch Tensor into an image Numpy array of BGR channel order
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)

    The input tensor is not modified.

    Raises:
        TypeError: if the squeezed tensor is not 2-, 3- or 4-dimensional.
    """
    # squeeze first, then clamp. The clamp is out-of-place: squeeze() returns
    # a view and float()/cpu() can return the same tensor, so clamp_() would
    # overwrite the caller's data.
    tensor = tensor.squeeze().float().cpu().clamp(*min_max)
    tensor = (tensor - min_max[0]) / (
        min_max[1] - min_max[0]
    )  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(
            tensor, nrow=int(math.sqrt(n_img)), normalize=False
        ).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(
                n_dim
            )
        )
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
    return img_np.astype(out_type)


"""
# --------------------------------------------
# Augmentation, flipe and/or rotate
# --------------------------------------------
# The following two are enough.
# (1) augmet_img: numpy image of WxHxC or WxH
# (2) augment_img_tensor4: tensor image 1xCxWxH
# --------------------------------------------
"""


def augment_img(img, mode=0):
    """Kai Zhang (github: https://github.com/cszn)

    Apply one of the 8 flip/rot90 combinations (mode 0-7) to a numpy image.
    Other modes return None.
    """
    if mode == 0:
        return img
    elif mode == 1:
        return np.flipud(np.rot90(img))
    elif mode == 2:
        return np.flipud(img)
    elif mode == 3:
        return np.rot90(img, k=3)
    elif mode == 4:
        return np.flipud(np.rot90(img, k=2))
    elif mode == 5:
        return np.rot90(img)
    elif mode == 6:
        return np.rot90(img, k=2)
    elif mode == 7:
        return np.flipud(np.rot90(img, k=3))


def augment_img_tensor4(img, mode=0):
    """Kai Zhang (github: https://github.com/cszn)

    Tensor counterpart of augment_img, rotating/flipping dims 2 and 3.
    """
    if mode == 0:
        return img
    elif mode == 1:
        return img.rot90(1, [2, 3]).flip([2])
    elif mode == 2:
        return img.flip([2])
    elif mode == 3:
        return img.rot90(3, [2, 3])
    elif mode == 4:
        return img.rot90(2, [2, 3]).flip([2])
    elif mode == 5:
        return img.rot90(1, [2, 3])
    elif mode == 6:
        return img.rot90(2, [2, 3])
    elif mode == 7:
        return img.rot90(3, [2, 3]).flip([2])


def augment_img_tensor(img, mode=0):
    """Kai Zhang (github: https://github.com/cszn)

    Apply augment_img to a 3-D (CxHxW) or 4-D (NxCxHxW) tensor via numpy and
    return a tensor of the input's type.
    """
    img_size = img.size()
    img_np = img.data.cpu().numpy()
    if len(img_size) == 3:
        img_np = np.transpose(img_np, (1, 2, 0))
    elif len(img_size) == 4:
        img_np = np.transpose(img_np, (2, 3, 1, 0))
    img_np = augment_img(img_np, mode=mode)
    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
    if len(img_size) == 3:
        img_tensor = img_tensor.permute(2, 0, 1)
    elif len(img_size) == 4:
        img_tensor = img_tensor.permute(3, 2, 0, 1)

    return img_tensor.type_as(img)


def augment_img_np3(img, mode=0):
    """Apply one of 8 flip/transpose combinations (mode 0-7) to an HxWxC array."""
    if mode == 0:
        return img
    elif mode == 1:
        return img.transpose(1, 0, 2)
    elif mode == 2:
        return img[::-1, :, :]
    elif mode == 3:
        img = img[::-1, :, :]
        img = img.transpose(1, 0, 2)
        return img
    elif mode == 4:
        return img[:, ::-1, :]
    elif mode == 5:
        img = img[:, ::-1, :]
        img = img.transpose(1, 0, 2)
        return img
    elif mode == 6:
        img = img[:, ::-1, :]
        img = img[::-1, :, :]
        return img
    elif mode == 7:
        img = img[:, ::-1, :]
        img = img[::-1, :, :]
        img = img.transpose(1, 0, 2)
        return img


def augment_imgs(img_list, hflip=True, rot=True):
    """Apply the same random h-flip / v-flip / transpose to every HxWxC image in the list."""
    # horizontal flip OR rotate
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip:
            img = img[:, ::-1, :]
        if vflip:
            img = img[::-1, :, :]
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    return [_augment(img) for img in img_list]


"""
# --------------------------------------------
# modcrop and shave
# --------------------------------------------
"""


def modcrop(img_in, scale):
    """Return a copy of ``img_in`` (HW or HWC) cropped so H and W are multiples of ``scale``.

    Raises:
        ValueError: for inputs that are not 2-D or 3-D.
    """
    img = np.copy(img_in)
    if img.ndim == 2:
        H, W = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[: H - H_r, : W - W_r]
    elif img.ndim == 3:
        H, W, C = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[: H - H_r, : W - W_r, :]
    else:
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    return img


def shave(img_in, border=0):
    """Return a copy of ``img_in`` (HW or HWC) with ``border`` pixels removed from each side."""
    img = np.copy(img_in)
    h, w = img.shape[:2]
    img = img[border : h - border, border : w - border]
    return img


"""
# --------------------------------------------
# image processing process on numpy image
# channel_convert(in_c, tar_type, img_list):
# rgb2ycbcr(img, only_y=True):
# bgr2ycbcr(img, only_y=True):
# ycbcr2rgb(img):
# --------------------------------------------
"""


def rgb2ycbcr(img, only_y=True):
    """same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Returns an array of the input dtype; the input itself is not modified.
    """
    in_img_type = img.dtype
    # Work on a float32 copy. The astype() result used to be discarded, so
    # the in-place scaling below rewrote the caller's float array.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(
            img,
            [
                [65.481, -37.797, 112.0],
                [128.553, -74.203, -93.786],
                [24.966, 112.0, -18.214],
            ],
        ) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)


def ycbcr2rgb(img):
    """same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Returns an array of the input dtype (uint8 output is saturated to
    [0, 255]); the input itself is not modified.
    """
    in_img_type = img.dtype
    img = img.astype(np.float32)  # private copy; never scale the caller's array
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    rlt = np.matmul(
        img,
        [
            [0.00456621, 0.00456621, 0.00456621],
            [0, -0.00153632, 0.00791071],
            [0.00625893, -0.00318811, 0],
        ],
    ) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        # Saturate like MATLAB instead of letting astype(uint8) wrap around.
        rlt = rlt.round().clip(0, 255)
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)


def bgr2ycbcr(img, only_y=True):
    """bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Returns an array of the input dtype; the input itself is not modified.
    """
    in_img_type = img.dtype
    img = img.astype(np.float32)  # private copy; never scale the caller's array
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(
            img,
            [
                [24.966, 112.0, -18.214],
                [128.553, -74.203, -93.786],
                [65.481, -37.797, 112.0],
            ],
        ) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)


def channel_convert(in_c, tar_type, img_list):
    """Convert a list of images among BGR, gray and Y; unknown combinations are returned as-is."""
    # conversion among BGR, gray and y
    if in_c == 3 and tar_type == 'gray':  # BGR to gray
        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in gray_list]
    elif in_c == 3 and tar_type == 'y':  # BGR to y
        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in y_list]
    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    else:
        return img_list


"""
# --------------------------------------------
# metric, PSNR and SSIM
# --------------------------------------------
"""


# --------------------------------------------
# PSNR
# --------------------------------------------
def calculate_psnr(img1, img2, border=0):
    """PSNR in dB between two images in the [0, 255] range.

    ``border`` pixels are ignored on each side. Returns inf for identical
    images.

    Raises:
        ValueError: if the shapes differ.
    """
    # img1 = img1.squeeze()
    # img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border : h - border, border : w - border]
    img2 = img2[border : h - border, border : w - border]

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


# --------------------------------------------
# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    """calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]

    HxWxC inputs return the mean of the per-channel SSIMs (any channel
    count; previously only C == 1 or 3 was handled and other counts
    silently returned None).

    Raises:
        ValueError: if the shapes differ or the inputs are not 2-D/3-D.
    """
    # img1 = img1.squeeze()
    # img2 = img2.squeeze()
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border : h - border, border : w - border]
    img2 = img2[border : h - border, border : w - border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        ssims = [
            ssim(img1[:, :, c], img2[:, :, c]) for c in range(img1.shape[2])
        ]
        return np.array(ssims).mean()
    raise ValueError('Wrong input image dimensions.')


def ssim(img1, img2):
    """Mean SSIM of two single-channel images in [0, 255].

    Uses an 11x11 Gaussian window (sigma 1.5); the 5-pixel border of the
    filtered maps is dropped ("valid" region).
    """
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )
    return ssim_map.mean()


"""
# --------------------------------------------
# matlab's bicubic imresize (numpy and torch) [0, 1]
# --------------------------------------------
"""


# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
    """Keys bicubic kernel (a = -0.5), evaluated element-wise on a tensor."""
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
        -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
    ) * (((absx > 1) * (absx <= 2)).type_as(absx))


def calculate_weights_indices(
    in_length, out_length, scale, kernel, kernel_width, antialiasing
):
    """Compute per-output-pixel cubic weights and input indices (MATLAB imresize port).

    Returns:
        (weights, indices, sym_len_s, sym_len_e): ``weights`` and ``indices``
        have one row per output pixel. ``indices`` are already shifted to
        address an input padded symmetrically by ``sym_len_s`` pixels at the
        start and ``sym_len_e`` at the end. ``kernel`` is unused (bicubic
        only).
    """
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation?  Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(
        0, P - 1, P
    ).view(1, P).expand(out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)


# --------------------------------------------
# imresize for tensor image [0, 1]
# --------------------------------------------
def imresize(img, scale, antialiasing=True):
    """MATLAB-style bicubic resize of a CxHxW or HxW tensor in [0, 1].

    The same ``scale`` is applied to H and W (output size is ceil(size *
    scale)); the result is not rounded. The input tensor is not modified.
    """
    # Out-of-place unsqueeze: unsqueeze_() used to reshape the caller's
    # 2-D tensor into 1xHxW as a side effect.
    need_squeeze = img.dim() == 2
    if need_squeeze:
        img = img.unsqueeze(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = (
        in_C,
        math.ceil(in_H * scale),
        math.ceil(in_W * scale),
    )
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing
    )
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing
    )
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = (
                img_aug[j, idx : idx + kernel_width, :]
                .transpose(0, 1)
                .mv(weights_H[i])
            )

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(
                weights_W[i]
            )
    if need_squeeze:
        # Drop only the channel axis added above, never a size-1 spatial axis.
        out_2.squeeze_(0)
    return out_2
-------------------------------------------- # imresize for numpy image [0, 1] # -------------------------------------------- def imresize_np(img, scale, antialiasing=True): # Now the scale should be the same for H and W # input: img: Numpy, HWC or HW [0,1] # output: HWC or HW [0,1] w/o round img = torch.from_numpy(img) need_squeeze = True if img.dim() == 2 else False if need_squeeze: img.unsqueeze_(2) in_H, in_W, in_C = img.size() out_C, out_H, out_W = ( in_C, math.ceil(in_H * scale), math.ceil(in_W * scale), ) kernel_width = 4 kernel = 'cubic' # Return the desired dimension order for performing the resize. The # strategy is to perform the resize first along the dimension with the # smallest scale factor. # Now we do not support this. # get weights and indices weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( in_H, out_H, scale, kernel, kernel_width, antialiasing ) weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( in_W, out_W, scale, kernel, kernel_width, antialiasing ) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) sym_patch = img[:sym_len_Hs, :, :] inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() sym_patch_inv = sym_patch.index_select(0, inv_idx) img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) sym_patch = img[-sym_len_He:, :, :] inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() sym_patch_inv = sym_patch.index_select(0, inv_idx) img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) out_1 = torch.FloatTensor(out_H, in_W, in_C) kernel_width = weights_H.size(1) for i in range(out_H): idx = int(indices_H[i][0]) for j in range(out_C): out_1[i, :, j] = ( img_aug[idx : idx + kernel_width, :, j] .transpose(0, 1) .mv(weights_H[i]) ) # process W dimension # symmetric copying out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) out_1_aug.narrow(1, 
sym_len_Ws, in_W).copy_(out_1) sym_patch = out_1[:, :sym_len_Ws, :] inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() sym_patch_inv = sym_patch.index_select(1, inv_idx) out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) sym_patch = out_1[:, -sym_len_We:, :] inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() sym_patch_inv = sym_patch.index_select(1, inv_idx) out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) out_2 = torch.FloatTensor(out_H, out_W, in_C) kernel_width = weights_W.size(1) for i in range(out_W): idx = int(indices_W[i][0]) for j in range(out_C): out_2[:, i, j] = out_1_aug[:, idx : idx + kernel_width, j].mv( weights_W[i] ) if need_squeeze: out_2.squeeze_() return out_2.numpy() if __name__ == '__main__': print('---') # img = imread_uint('test.bmp', 3) # img = uint2single(img) # img_bicubic = imresize_np(img, 1/4) ================================================ FILE: src/stablediffusion/ldm/modules/losses/__init__.py ================================================ from src.stablediffusion.ldm.modules.losses.contperceptual import LPIPSWithDiscriminator ================================================ FILE: src/stablediffusion/ldm/modules/losses/contperceptual.py ================================================ import torch import torch.nn as nn from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
class LPIPSWithDiscriminator(nn.Module): def __init__( self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, disc_loss='hinge', ): super().__init__() assert disc_loss in ['hinge', 'vanilla'] self.kl_weight = kl_weight self.pixel_weight = pixelloss_weight self.perceptual_loss = LPIPS().eval() self.perceptual_weight = perceptual_weight # output log variance self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) self.discriminator = NLayerDiscriminator( input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm, ).apply(weights_init) self.discriminator_iter_start = disc_start self.disc_loss = ( hinge_d_loss if disc_loss == 'hinge' else vanilla_d_loss ) self.disc_factor = disc_factor self.discriminator_weight = disc_weight self.disc_conditional = disc_conditional def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): if last_layer is not None: nll_grads = torch.autograd.grad( nll_loss, last_layer, retain_graph=True )[0] g_grads = torch.autograd.grad( g_loss, last_layer, retain_graph=True )[0] else: nll_grads = torch.autograd.grad( nll_loss, self.last_layer[0], retain_graph=True )[0] g_grads = torch.autograd.grad( g_loss, self.last_layer[0], retain_graph=True )[0] d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() d_weight = d_weight * self.discriminator_weight return d_weight def forward( self, inputs, reconstructions, posteriors, optimizer_idx, global_step, last_layer=None, cond=None, split='train', weights=None, ): rec_loss = torch.abs( inputs.contiguous() - reconstructions.contiguous() ) if self.perceptual_weight > 0: p_loss = self.perceptual_loss( inputs.contiguous(), reconstructions.contiguous() ) rec_loss = rec_loss + self.perceptual_weight * p_loss nll_loss = rec_loss / torch.exp(self.logvar) + 
self.logvar weighted_nll_loss = nll_loss if weights is not None: weighted_nll_loss = weights * nll_loss weighted_nll_loss = ( torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] ) nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] kl_loss = posteriors.kl() kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] # now the GAN part if optimizer_idx == 0: # generator update if cond is None: assert not self.disc_conditional logits_fake = self.discriminator(reconstructions.contiguous()) else: assert self.disc_conditional logits_fake = self.discriminator( torch.cat((reconstructions.contiguous(), cond), dim=1) ) g_loss = -torch.mean(logits_fake) if self.disc_factor > 0.0: try: d_weight = self.calculate_adaptive_weight( nll_loss, g_loss, last_layer=last_layer ) except RuntimeError: assert not self.training d_weight = torch.tensor(0.0) else: d_weight = torch.tensor(0.0) disc_factor = adopt_weight( self.disc_factor, global_step, threshold=self.discriminator_iter_start, ) loss = ( weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss ) log = { '{}/total_loss'.format(split): loss.clone().detach().mean(), '{}/logvar'.format(split): self.logvar.detach(), '{}/kl_loss'.format(split): kl_loss.detach().mean(), '{}/nll_loss'.format(split): nll_loss.detach().mean(), '{}/rec_loss'.format(split): rec_loss.detach().mean(), '{}/d_weight'.format(split): d_weight.detach(), '{}/disc_factor'.format(split): torch.tensor(disc_factor), '{}/g_loss'.format(split): g_loss.detach().mean(), } return loss, log if optimizer_idx == 1: # second pass for discriminator update if cond is None: logits_real = self.discriminator(inputs.contiguous().detach()) logits_fake = self.discriminator( reconstructions.contiguous().detach() ) else: logits_real = self.discriminator( torch.cat((inputs.contiguous().detach(), cond), dim=1) ) logits_fake = self.discriminator( torch.cat( (reconstructions.contiguous().detach(), cond), dim=1 ) ) disc_factor = adopt_weight( self.disc_factor, global_step, 
threshold=self.discriminator_iter_start, ) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) log = { '{}/disc_loss'.format(split): d_loss.clone().detach().mean(), '{}/logits_real'.format(split): logits_real.detach().mean(), '{}/logits_fake'.format(split): logits_fake.detach().mean(), } return d_loss, log ================================================ FILE: src/stablediffusion/ldm/modules/losses/vqperceptual.py ================================================ import torch from torch import nn import torch.nn.functional as F from einops import repeat from taming.modules.discriminator.model import ( NLayerDiscriminator, weights_init, ) from taming.modules.losses.lpips import LPIPS from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] loss_real = torch.mean(F.relu(1.0 - logits_real), dim=[1, 2, 3]) loss_fake = torch.mean(F.relu(1.0 + logits_fake), dim=[1, 2, 3]) loss_real = (weights * loss_real).sum() / weights.sum() loss_fake = (weights * loss_fake).sum() / weights.sum() d_loss = 0.5 * (loss_real + loss_fake) return d_loss def adopt_weight(weight, global_step, threshold=0, value=0.0): if global_step < threshold: weight = value return weight def measure_perplexity(predicted_indices, n_embed): # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally encodings = ( F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) ) avg_probs = encodings.mean(0) perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() cluster_use = torch.sum(avg_probs > 0) return perplexity, cluster_use def l1(x, y): return torch.abs(x - y) def l2(x, y): return torch.pow((x - y), 2) class VQLPIPSWithDiscriminator(nn.Module): def __init__( self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, disc_ndf=64, disc_loss='hinge', n_classes=None, perceptual_loss='lpips', pixel_loss='l1', ): super().__init__() assert disc_loss in ['hinge', 'vanilla'] assert perceptual_loss in ['lpips', 'clips', 'dists'] assert pixel_loss in ['l1', 'l2'] self.codebook_weight = codebook_weight self.pixel_weight = pixelloss_weight if perceptual_loss == 'lpips': print(f'{self.__class__.__name__}: Running with LPIPS.') self.perceptual_loss = LPIPS().eval() else: raise ValueError( f'Unknown perceptual loss: >> {perceptual_loss} <<' ) self.perceptual_weight = perceptual_weight if pixel_loss == 'l1': self.pixel_loss = l1 else: self.pixel_loss = l2 self.discriminator = NLayerDiscriminator( input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm, ndf=disc_ndf, ).apply(weights_init) self.discriminator_iter_start = disc_start if disc_loss == 'hinge': self.disc_loss = hinge_d_loss elif disc_loss == 'vanilla': self.disc_loss = vanilla_d_loss else: raise ValueError(f"Unknown GAN loss '{disc_loss}'.") print(f'VQLPIPSWithDiscriminator running with {disc_loss} loss.') self.disc_factor = disc_factor self.discriminator_weight = disc_weight self.disc_conditional = disc_conditional self.n_classes = n_classes def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): if last_layer is not None: nll_grads = 
torch.autograd.grad( nll_loss, last_layer, retain_graph=True )[0] g_grads = torch.autograd.grad( g_loss, last_layer, retain_graph=True )[0] else: nll_grads = torch.autograd.grad( nll_loss, self.last_layer[0], retain_graph=True )[0] g_grads = torch.autograd.grad( g_loss, self.last_layer[0], retain_graph=True )[0] d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() d_weight = d_weight * self.discriminator_weight return d_weight def forward( self, codebook_loss, inputs, reconstructions, optimizer_idx, global_step, last_layer=None, cond=None, split='train', predicted_indices=None, ): if not exists(codebook_loss): codebook_loss = torch.tensor([0.0]).to(inputs.device) # rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) rec_loss = self.pixel_loss( inputs.contiguous(), reconstructions.contiguous() ) if self.perceptual_weight > 0: p_loss = self.perceptual_loss( inputs.contiguous(), reconstructions.contiguous() ) rec_loss = rec_loss + self.perceptual_weight * p_loss else: p_loss = torch.tensor([0.0]) nll_loss = rec_loss # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] nll_loss = torch.mean(nll_loss) # now the GAN part if optimizer_idx == 0: # generator update if cond is None: assert not self.disc_conditional logits_fake = self.discriminator(reconstructions.contiguous()) else: assert self.disc_conditional logits_fake = self.discriminator( torch.cat((reconstructions.contiguous(), cond), dim=1) ) g_loss = -torch.mean(logits_fake) try: d_weight = self.calculate_adaptive_weight( nll_loss, g_loss, last_layer=last_layer ) except RuntimeError: assert not self.training d_weight = torch.tensor(0.0) disc_factor = adopt_weight( self.disc_factor, global_step, threshold=self.discriminator_iter_start, ) loss = ( nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() ) log = { '{}/total_loss'.format(split): loss.clone().detach().mean(), '{}/quant_loss'.format(split): 
codebook_loss.detach().mean(), '{}/nll_loss'.format(split): nll_loss.detach().mean(), '{}/rec_loss'.format(split): rec_loss.detach().mean(), '{}/p_loss'.format(split): p_loss.detach().mean(), '{}/d_weight'.format(split): d_weight.detach(), '{}/disc_factor'.format(split): torch.tensor(disc_factor), '{}/g_loss'.format(split): g_loss.detach().mean(), } if predicted_indices is not None: assert self.n_classes is not None with torch.no_grad(): perplexity, cluster_usage = measure_perplexity( predicted_indices, self.n_classes ) log[f'{split}/perplexity'] = perplexity log[f'{split}/cluster_usage'] = cluster_usage return loss, log if optimizer_idx == 1: # second pass for discriminator update if cond is None: logits_real = self.discriminator(inputs.contiguous().detach()) logits_fake = self.discriminator( reconstructions.contiguous().detach() ) else: logits_real = self.discriminator( torch.cat((inputs.contiguous().detach(), cond), dim=1) ) logits_fake = self.discriminator( torch.cat( (reconstructions.contiguous().detach(), cond), dim=1 ) ) disc_factor = adopt_weight( self.disc_factor, global_step, threshold=self.discriminator_iter_start, ) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) log = { '{}/disc_loss'.format(split): d_loss.clone().detach().mean(), '{}/logits_real'.format(split): logits_real.detach().mean(), '{}/logits_fake'.format(split): logits_fake.detach().mean(), } return d_loss, log ================================================ FILE: src/stablediffusion/ldm/modules/x_transformer.py ================================================ """shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" import torch from torch import nn, einsum import torch.nn.functional as F from functools import partial from inspect import isfunction from collections import namedtuple from einops import rearrange, repeat, reduce # constants DEFAULT_DIM_HEAD = 64 Intermediates = namedtuple( 'Intermediates', ['pre_softmax_attn', 
'post_softmax_attn'] ) LayerIntermediates = namedtuple( 'Intermediates', ['hiddens', 'attn_intermediates'] ) class AbsolutePositionalEmbedding(nn.Module): def __init__(self, dim, max_seq_len): super().__init__() self.emb = nn.Embedding(max_seq_len, dim) self.init_() def init_(self): nn.init.normal_(self.emb.weight, std=0.02) def forward(self, x): n = torch.arange(x.shape[1], device=x.device) return self.emb(n)[None, :, :] class FixedPositionalEmbedding(nn.Module): def __init__(self, dim): super().__init__() inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) def forward(self, x, seq_dim=1, offset=0): t = ( torch.arange(x.shape[seq_dim], device=x.device).type_as( self.inv_freq ) + offset ) sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) return emb[None, :, :] # helpers def exists(val): return val is not None def default(val, d): if exists(val): return val return d() if isfunction(d) else d def always(val): def inner(*args, **kwargs): return val return inner def not_equals(val): def inner(x): return x != val return inner def equals(val): def inner(x): return x == val return inner def max_neg_value(tensor): return -torch.finfo(tensor.dtype).max # keyword argument helpers def pick_and_pop(keys, d): values = list(map(lambda key: d.pop(key), keys)) return dict(zip(keys, values)) def group_dict_by_key(cond, d): return_val = [dict(), dict()] for key in d.keys(): match = bool(cond(key)) ind = int(not match) return_val[ind][key] = d[key] return (*return_val,) def string_begins_with(prefix, str): return str.startswith(prefix) def group_by_key_prefix(prefix, d): return group_dict_by_key(partial(string_begins_with, prefix), d) def groupby_prefix_and_trim(prefix, d): kwargs_with_prefix, kwargs = group_dict_by_key( partial(string_begins_with, prefix), d ) kwargs_without_prefix = dict( map( lambda x: (x[0][len(prefix) :], x[1]), 
tuple(kwargs_with_prefix.items()), ) ) return kwargs_without_prefix, kwargs # classes class Scale(nn.Module): def __init__(self, value, fn): super().__init__() self.value = value self.fn = fn def forward(self, x, **kwargs): x, *rest = self.fn(x, **kwargs) return (x * self.value, *rest) class Rezero(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn self.g = nn.Parameter(torch.zeros(1)) def forward(self, x, **kwargs): x, *rest = self.fn(x, **kwargs) return (x * self.g, *rest) class ScaleNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.scale = dim**-0.5 self.eps = eps self.g = nn.Parameter(torch.ones(1)) def forward(self, x): norm = torch.norm(x, dim=-1, keepdim=True) * self.scale return x / norm.clamp(min=self.eps) * self.g class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-8): super().__init__() self.scale = dim**-0.5 self.eps = eps self.g = nn.Parameter(torch.ones(dim)) def forward(self, x): norm = torch.norm(x, dim=-1, keepdim=True) * self.scale return x / norm.clamp(min=self.eps) * self.g class Residual(nn.Module): def forward(self, x, residual): return x + residual class GRUGating(nn.Module): def __init__(self, dim): super().__init__() self.gru = nn.GRUCell(dim, dim) def forward(self, x, residual): gated_output = self.gru( rearrange(x, 'b n d -> (b n) d'), rearrange(residual, 'b n d -> (b n) d'), ) return gated_output.reshape_as(x) # feedforward class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) return x * F.gelu(gate) class FeedForward(nn.Module): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = ( nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) ) self.net = nn.Sequential( project_in, nn.Dropout(dropout), nn.Linear(inner_dim, 
dim_out) ) def forward(self, x): return self.net(x) # attention. class Attention(nn.Module): def __init__( self, dim, dim_head=DEFAULT_DIM_HEAD, heads=8, causal=False, mask=None, talking_heads=False, sparse_topk=None, use_entmax15=False, num_mem_kv=0, dropout=0.0, on_attn=False, ): super().__init__() if use_entmax15: raise NotImplementedError( 'Check out entmax activation instead of softmax activation!' ) self.scale = dim_head**-0.5 self.heads = heads self.causal = causal self.mask = mask inner_dim = dim_head * heads self.to_q = nn.Linear(dim, inner_dim, bias=False) self.to_k = nn.Linear(dim, inner_dim, bias=False) self.to_v = nn.Linear(dim, inner_dim, bias=False) self.dropout = nn.Dropout(dropout) # talking heads self.talking_heads = talking_heads if talking_heads: self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) # explicit topk sparse attention self.sparse_topk = sparse_topk # entmax # self.attn_fn = entmax15 if use_entmax15 else F.softmax self.attn_fn = F.softmax # add memory key / values self.num_mem_kv = num_mem_kv if num_mem_kv > 0: self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) # attention on attention self.attn_on_attn = on_attn self.to_out = ( nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) ) def forward( self, x, context=None, mask=None, context_mask=None, rel_pos=None, sinusoidal_emb=None, prev_attn=None, mem=None, ): b, n, _, h, talking_heads, device = ( *x.shape, self.heads, self.talking_heads, x.device, ) kv_input = default(context, x) q_input = x k_input = kv_input v_input = kv_input if exists(mem): k_input = torch.cat((mem, k_input), dim=-2) v_input = torch.cat((mem, v_input), dim=-2) if exists(sinusoidal_emb): # in shortformer, the query would start at a position offset depending on the past cached memory offset = k_input.shape[-2] - 
q_input.shape[-2] q_input = q_input + sinusoidal_emb(q_input, offset=offset) k_input = k_input + sinusoidal_emb(k_input) q = self.to_q(q_input) k = self.to_k(k_input) v = self.to_v(v_input) q, k, v = map( lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v) ) input_mask = None if any(map(exists, (mask, context_mask))): q_mask = default( mask, lambda: torch.ones((b, n), device=device).bool() ) k_mask = q_mask if not exists(context) else context_mask k_mask = default( k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool(), ) q_mask = rearrange(q_mask, 'b i -> b () i ()') k_mask = rearrange(k_mask, 'b j -> b () () j') input_mask = q_mask * k_mask if self.num_mem_kv > 0: mem_k, mem_v = map( lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v), ) k = torch.cat((mem_k, k), dim=-2) v = torch.cat((mem_v, v), dim=-2) if exists(input_mask): input_mask = F.pad( input_mask, (self.num_mem_kv, 0), value=True ) dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale mask_value = max_neg_value(dots) if exists(prev_attn): dots = dots + prev_attn pre_softmax_attn = dots if talking_heads: dots = einsum( 'b h i j, h k -> b k i j', dots, self.pre_softmax_proj ).contiguous() if exists(rel_pos): dots = rel_pos(dots) if exists(input_mask): dots.masked_fill_(~input_mask, mask_value) del input_mask if self.causal: i, j = dots.shape[-2:] r = torch.arange(i, device=device) mask = rearrange(r, 'i -> () () i ()') < rearrange( r, 'j -> () () () j' ) mask = F.pad(mask, (j - i, 0), value=False) dots.masked_fill_(mask, mask_value) del mask if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: top, _ = dots.topk(self.sparse_topk, dim=-1) vk = top[..., -1].unsqueeze(-1).expand_as(dots) mask = dots < vk dots.masked_fill_(mask, mask_value) del mask attn = self.attn_fn(dots, dim=-1) post_softmax_attn = attn attn = self.dropout(attn) if talking_heads: attn = einsum( 'b h i j, h k -> b k i j', attn, self.post_softmax_proj ).contiguous() out = 
einsum('b h i j, b h j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') intermediates = Intermediates( pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn, ) return self.to_out(out), intermediates class AttentionLayers(nn.Module): def __init__( self, dim, depth, heads=8, causal=False, cross_attend=False, only_cross=False, use_scalenorm=False, use_rmsnorm=False, use_rezero=False, rel_pos_num_buckets=32, rel_pos_max_distance=128, position_infused_attn=False, custom_layers=None, sandwich_coef=None, par_ratio=None, residual_attn=False, cross_residual_attn=False, macaron=False, pre_norm=True, gate_residual=False, **kwargs, ): super().__init__() ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) self.dim = dim self.depth = depth self.layers = nn.ModuleList([]) self.has_pos_emb = position_infused_attn self.pia_pos_emb = ( FixedPositionalEmbedding(dim) if position_infused_attn else None ) self.rotary_pos_emb = always(None) assert ( rel_pos_num_buckets <= rel_pos_max_distance ), 'number of relative position buckets must be less than the relative position max distance' self.rel_pos = None self.pre_norm = pre_norm self.residual_attn = residual_attn self.cross_residual_attn = cross_residual_attn norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm norm_class = RMSNorm if use_rmsnorm else norm_class norm_fn = partial(norm_class, dim) norm_fn = nn.Identity if use_rezero else norm_fn branch_fn = Rezero if use_rezero else None if cross_attend and not only_cross: default_block = ('a', 'c', 'f') elif cross_attend and only_cross: default_block = ('c', 'f') else: default_block = ('a', 'f') if macaron: default_block = ('f',) + default_block if exists(custom_layers): layer_types = custom_layers elif exists(par_ratio): par_depth = depth * len(default_block) assert 1 < par_ratio <= par_depth, 'par ratio out of range' 
default_block = tuple(filter(not_equals('f'), default_block)) par_attn = par_depth // par_ratio depth_cut = ( par_depth * 2 // 3 ) # 2 / 3 attention layer cutoff suggested by PAR paper par_width = (depth_cut + depth_cut // par_attn) // par_attn assert ( len(default_block) <= par_width ), 'default block is too large for par_ratio' par_block = default_block + ('f',) * ( par_width - len(default_block) ) par_head = par_block * par_attn layer_types = par_head + ('f',) * (par_depth - len(par_head)) elif exists(sandwich_coef): assert ( sandwich_coef > 0 and sandwich_coef <= depth ), 'sandwich coefficient should be less than the depth' layer_types = ( ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef ) else: layer_types = default_block * depth self.layer_types = layer_types self.num_attn_layers = len(list(filter(equals('a'), layer_types))) for layer_type in self.layer_types: if layer_type == 'a': layer = Attention( dim, heads=heads, causal=causal, **attn_kwargs ) elif layer_type == 'c': layer = Attention(dim, heads=heads, **attn_kwargs) elif layer_type == 'f': layer = FeedForward(dim, **ff_kwargs) layer = layer if not macaron else Scale(0.5, layer) else: raise Exception(f'invalid layer type {layer_type}') if isinstance(layer, Attention) and exists(branch_fn): layer = branch_fn(layer) if gate_residual: residual_fn = GRUGating(dim) else: residual_fn = Residual() self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn])) def forward( self, x, context=None, mask=None, context_mask=None, mems=None, return_hiddens=False, **kwargs, ): hiddens = [] intermediates = [] prev_attn = None prev_cross_attn = None mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers for ind, (layer_type, (norm, block, residual_fn)) in enumerate( zip(self.layer_types, self.layers) ): is_last = ind == (len(self.layers) - 1) if layer_type == 'a': hiddens.append(x) layer_mem = mems.pop(0) residual = x if self.pre_norm: x = norm(x) if layer_type 
== 'a': out, inter = block( x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, prev_attn=prev_attn, mem=layer_mem, ) elif layer_type == 'c': out, inter = block( x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn, ) elif layer_type == 'f': out = block(x) x = residual_fn(out, residual) if layer_type in ('a', 'c'): intermediates.append(inter) if layer_type == 'a' and self.residual_attn: prev_attn = inter.pre_softmax_attn elif layer_type == 'c' and self.cross_residual_attn: prev_cross_attn = inter.pre_softmax_attn if not self.pre_norm and not is_last: x = norm(x) if return_hiddens: intermediates = LayerIntermediates( hiddens=hiddens, attn_intermediates=intermediates ) return x, intermediates return x class Encoder(AttentionLayers): def __init__(self, **kwargs): assert 'causal' not in kwargs, 'cannot set causality on encoder' super().__init__(causal=False, **kwargs) class TransformerWrapper(nn.Module): def __init__( self, *, num_tokens, max_seq_len, attn_layers, emb_dim=None, max_mem_len=0.0, emb_dropout=0.0, num_memory_tokens=None, tie_embedding=False, use_pos_emb=True, ): super().__init__() assert isinstance( attn_layers, AttentionLayers ), 'attention layers must be one of Encoder or Decoder' dim = attn_layers.dim emb_dim = default(emb_dim, dim) self.max_seq_len = max_seq_len self.max_mem_len = max_mem_len self.num_tokens = num_tokens self.token_emb = nn.Embedding(num_tokens, emb_dim) self.pos_emb = ( AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0) ) self.emb_dropout = nn.Dropout(emb_dropout) self.project_emb = ( nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() ) self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) self.init_() self.to_logits = ( nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() ) # memory tokens (like [cls]) from Memory Transformers paper num_memory_tokens = 
default(num_memory_tokens, 0) self.num_memory_tokens = num_memory_tokens if num_memory_tokens > 0: self.memory_tokens = nn.Parameter( torch.randn(num_memory_tokens, dim) ) # let funnel encoder know number of memory tokens, if specified if hasattr(attn_layers, 'num_memory_tokens'): attn_layers.num_memory_tokens = num_memory_tokens def init_(self): nn.init.normal_(self.token_emb.weight, std=0.02) def forward( self, x, return_embeddings=False, mask=None, return_mems=False, return_attn=False, mems=None, embedding_manager=None, **kwargs, ): b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens embedded_x = self.token_emb(x) if embedding_manager: x = embedding_manager(x, embedded_x) else: x = embedded_x x = x + self.pos_emb(x) x = self.emb_dropout(x) x = self.project_emb(x) if num_mem > 0: mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) x = torch.cat((mem, x), dim=1) # auto-handle masking after appending memory tokens if exists(mask): mask = F.pad(mask, (num_mem, 0), value=True) x, intermediates = self.attn_layers( x, mask=mask, mems=mems, return_hiddens=True, **kwargs ) x = self.norm(x) mem, x = x[:, :num_mem], x[:, num_mem:] out = self.to_logits(x) if not return_embeddings else x if return_mems: hiddens = intermediates.hiddens new_mems = ( list( map( lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens), ) ) if exists(mems) else hiddens ) new_mems = list( map( lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems ) ) return out, new_mems if return_attn: attn_maps = list( map( lambda t: t.post_softmax_attn, intermediates.attn_intermediates, ) ) return out, attn_maps return out ================================================ FILE: src/stablediffusion/ldm/simplet2i.py ================================================ ''' This module is provided for backward compatibility with the original (hasty) API. Please use ldm.generate instead. 
''' from src.stablediffusion.ldm.generate import Generate class T2I(Generate): def __init__(self,**kwargs): print(f'>> The ldm.simplet2i module is deprecated. Use ldm.generate instead. It is a drop-in replacement.') super().__init__(kwargs) ================================================ FILE: src/stablediffusion/ldm/util.py ================================================ import importlib import torch import numpy as np from collections import abc from einops import rearrange from functools import partial import multiprocessing as mp from threading import Thread from queue import Queue from inspect import isfunction from PIL import Image, ImageDraw, ImageFont def log_txt_as_img(wh, xc, size=10): # wh a tuple of (width, height) # xc a list of captions to plot b = len(xc) txts = list() for bi in range(b): txt = Image.new('RGB', wh, color='white') draw = ImageDraw.Draw(txt) font = ImageFont.load_default() nc = int(40 * (wh[0] / 256)) lines = '\n'.join( xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) ) try: draw.text((0, 0), lines, fill='black', font=font) except UnicodeEncodeError: print('Cant encode string for logging. Skipping.') txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txts.append(txt) txts = np.stack(txts) txts = torch.tensor(txts) return txts def ismap(x): if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] > 3) def isimage(x): if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) def exists(x): return x is not None def default(val, d): if exists(val): return val return d() if isfunction(d) else d def mean_flat(tensor): """ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 Take the mean over all non-batch dimensions. 
""" return tensor.mean(dim=list(range(1, len(tensor.shape)))) def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: print( f'{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.' ) return total_params def instantiate_from_config(config, **kwargs): if not 'target' in config: if config == '__is_first_stage__': return None elif config == '__is_unconditional__': return None raise KeyError('Expected key `target` to instantiate.') return get_obj_from_str(config['target'])( **config.get('params', dict()), **kwargs ) def get_obj_from_str(string, reload=False): module, cls = string.rsplit('.', 1) if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) return getattr(importlib.import_module(module, package=None), cls) def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): # create dummy dataset instance # run prefetching if idx_to_fn: res = func(data, worker_id=idx) else: res = func(data) Q.put([idx, res]) Q.put('Done') def parallel_data_prefetch( func: callable, data, n_proc, target_data_type='ndarray', cpu_intensive=True, use_worker_id=False, ): # if target_data_type not in ["ndarray", "list"]: # raise ValueError( # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." # ) if isinstance(data, np.ndarray) and target_data_type == 'list': raise ValueError('list expected but function got ndarray.') elif isinstance(data, abc.Iterable): if isinstance(data, dict): print( f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' ) data = list(data.values()) if target_data_type == 'ndarray': data = np.asarray(data) else: data = list(data) else: raise TypeError( f'The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.' 
) if cpu_intensive: Q = mp.Queue(1000) proc = mp.Process else: Q = Queue(1000) proc = Thread # spawn processes if target_data_type == 'ndarray': arguments = [ [func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc)) ] else: step = ( int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc) ) arguments = [ [func, Q, part, i, use_worker_id] for i, part in enumerate( [data[i : i + step] for i in range(0, len(data), step)] ) ] processes = [] for i in range(n_proc): p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) processes += [p] # start processes print(f'Start prefetching...') import time start = time.time() gather_res = [[] for _ in range(n_proc)] try: for p in processes: p.start() k = 0 while k < n_proc: # get result res = Q.get() if res == 'Done': k += 1 else: gather_res[res[0]] = res[1] except Exception as e: print('Exception: ', e) for p in processes: p.terminate() raise e finally: for p in processes: p.join() print(f'Prefetching complete. 
[{time.time() - start} sec.]') if target_data_type == 'ndarray': if not isinstance(gather_res[0], np.ndarray): return np.concatenate([np.asarray(r) for r in gather_res], axis=0) # order outputs return np.concatenate(gather_res, axis=0) elif target_data_type == 'list': out = [] for r in gather_res: out.extend(r) return out else: return gather_res ================================================ FILE: src/stablediffusion/text2image_compvis.py ================================================ import os import torch import numpy as np from PIL import Image from pytorch_lightning import seed_everything from torch import autocast from src.stablediffusion.ldm.generate import Generate import uuid import shutil # 0 = resize # 1 = crop and resize # 2 = resize and fill def resize_image(resize_mode, im, width, height): LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) if resize_mode == 0: res = im.resize((width, height), resample=LANCZOS) elif resize_mode == 1: ratio = width / height src_ratio = im.width / im.height src_w = width if ratio > src_ratio else im.width * height // im.height src_h = height if ratio <= src_ratio else im.height * width // im.width resized = im.resize((src_w, src_h), resample=LANCZOS) res = Image.new("RGB", (width, height)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) else: ratio = width / height src_ratio = im.width / im.height src_w = width if ratio < src_ratio else im.width * height // im.height src_h = height if ratio >= src_ratio else im.height * width // im.width resized = im.resize((src_w, src_h), resample=LANCZOS) res = Image.new("RGB", (width, height)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) if ratio < src_ratio: fill_height = height // 2 - src_h // 2 res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, 
fill_height + src_h)) elif ratio > src_ratio: fill_width = width // 2 - src_w // 2 res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) return res class Text2Image: def __init__(self, model_path='models/model-epoch06-full.ckpt', use_gpu=True): self.generator = Generate(weights=model_path, config='models/v1-inference.yaml') try: self.generator.load_model() except: import sys, traceback traceback.print_exc(file=sys.stdout) def dream(self, prompt: str, ddim_steps: int, plms: bool, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int, progress: bool, sampler_name: str): seed = seed_everything(seed) id = str(uuid.uuid4()) results = self.generator.txt2img(prompt=prompt, iterations = 1, steps=ddim_steps, seed=seed, cfg_scale=cfg_scale, ddim_eta=ddim_eta, width=width, height=height, sampler_name=sampler_name, outdir='storage/outputs') shutil.move(results[0][0], f'storage/outputs/{id}.png') return [Image.open(f'storage/outputs/{id}.png')], results[0][1] def translation(self, prompt: str, init_img, ddim_steps: int, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, sampler_name: str): seed = seed_everything(seed) id = str(uuid.uuid4()) image = init_img.convert("RGB") image = resize_image(1, image, width, height) image.save(f'storage/init/{id}.png') results = self.generator.txt2img(prompt=prompt, iterations = 1, steps=ddim_steps, seed=seed, cfg_scale=cfg_scale, ddim_eta=ddim_eta, width=width, height=height, sampler_name=sampler_name, outdir='storage/outputs', init_img=f'storage/init/{id}.png', strength=denoising_strength) shutil.move(results[0][0], f'storage/outputs/{id}.png') return [Image.open(f'storage/outputs/{id}.png')], results[0][1] def inpaint(self, prompt: str, init_img, 
mask_img, ddim_steps: int, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int): seed = seed_everything(seed) id = str(uuid.uuid4()) image = init_img.convert("RGB") image = resize_image(1, image, width, height) image.save(f'storage/init/{id}.png') image_mask = mask_image.convert("RGB") image_mask = resize_image(1, image_mask, width, height) image_mask.save(f'storage/init/{id}-mask.png') results = self.generator.txt2img(prompt=prompt, iterations = 1, steps=ddim_steps, seed=seed, cfg_scale=cfg_scale, ddim_eta=ddim_eta, width=width, height=height, sampler_name=sampler_name, outdir='storage/outputs', init_img=f'storage/init/{id}.png', init_mask=f'storage/init/{id}-mask.png', strength=denoising_strength) shutil.move(results[0][0], f'storage/outputs/{id}.png') return [Image.open(f'storage/outputs/{id}.png')], results[0][1] ================================================ FILE: src/stablediffusion/text2image_diffusers.py ================================================ import os import torch import numpy as np from PIL import Image from pytorch_lightning import seed_everything from torch import autocast from transformers import CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler, StableDiffusionPipeline, DDIMScheduler, PNDMScheduler from src.stablediffusion.inpaint import StableDiffusionInpaintingPipeline, preprocess, preprocess_mask from src.stablediffusion.translation import StableDiffusionImg2ImgPipeline from src.stablediffusion.dream import StableDiffusionPipeline # 0 = resize # 1 = crop and resize # 2 = resize and fill def resize_image(resize_mode, im, width, height): LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) if resize_mode == 0: res = im.resize((width, height), resample=LANCZOS) elif resize_mode == 1: ratio = width / height src_ratio = im.width / im.height src_w = width if ratio > src_ratio 
else im.width * height // im.height src_h = height if ratio <= src_ratio else im.height * width // im.width resized = im.resize((src_w, src_h), resample=LANCZOS) res = Image.new("RGB", (width, height)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) else: ratio = width / height src_ratio = im.width / im.height src_w = width if ratio < src_ratio else im.width * height // im.height src_h = height if ratio >= src_ratio else im.height * width // im.width resized = im.resize((src_w, src_h), resample=LANCZOS) res = Image.new("RGB", (width, height)) res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) if ratio < src_ratio: fill_height = height // 2 - src_h // 2 res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) elif ratio > src_ratio: fill_width = width // 2 - src_w // 2 res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) return res class Text2Image: def __init__(self, use_gpu=True): self.device = torch.device('cuda' if use_gpu else 'cpu') self.dtype = torch.float16 if use_gpu else torch.float32 model_name = 'CompVis/stable-diffusion-v1-4' token = os.environ['HF_TOKEN'] self.vae = AutoencoderKL.from_pretrained(model_name, subfolder='vae', revision="fp16", use_auth_token=token) self.unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", revision="fp16", use_auth_token=token) self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") self.scheduler = LMSDiscreteScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 ) self.img2img_scheduler = 
PNDMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, skip_prk_steps=True ) self.vae = self.vae.to(self.dtype).eval().to(self.device) self.text_encoder = self.text_encoder.to(self.dtype).eval().to(self.device) self.unet = self.unet.to(self.dtype).eval().to(self.device) self.inpaint_pipe = StableDiffusionInpaintingPipeline( self.vae, self.text_encoder, self.tokenizer, self.unet, self.img2img_scheduler ) self.dream_pipe = StableDiffusionPipeline( self.vae, self.text_encoder, self.tokenizer, self.unet, self.scheduler ) self.translation_pipe = StableDiffusionImg2ImgPipeline( self.vae, self.text_encoder, self.tokenizer, self.unet, self.img2img_scheduler ) def dream(self, prompt: str, ddim_steps: int, plms: bool, fixed_code: bool, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, seed: int, height: int, width: int, progress: bool): rng_seed = seed_everything(seed) with autocast('cuda'): image = self.dream_pipe(prompt, height=height, width=width, guidance_scale=cfg_scale, eta=ddim_eta, num_inference_steps=ddim_steps, progress=progress)['sample'] return image, rng_seed def translation(self, prompt: str, init_img, ddim_steps: int, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int): rng_seed = seed_everything(seed) image = init_img.convert("RGB") image = resize_image(1, image, width, height) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) image = 2.0 * image - 1.0 with autocast('cuda'): image = self.translation_pipe(prompt, image, denoising_strength, ddim_steps, cfg_scale, ddim_eta, None, 'pil')['sample'] return image, rng_seed def inpaint(self, prompt: str, init_img, mask_img, ddim_steps: int, ddim_eta: float, n_iter: int, n_samples: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int): rng_seed = seed_everything(seed) 
init_img = resize_image(1, init_img, width, height) # mask = np.array(init_img.convert('RGBA').split()[-1]) # mask = Image.fromarray(mask) init_img_tensor = preprocess(init_img.convert('RGB')) with autocast('cuda'): image = self.inpaint_pipe(prompt, init_img_tensor, mask_img, denoising_strength, ddim_steps, cfg_scale, ddim_eta, None, 'pil')['sample'] return image, rng_seed @torch.no_grad() def vae_test(self, image, height: int, width: int): image = image.convert("RGB") image = resize_image(1, image, width, height) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) image = 2.0 * image - 1.0 with autocast('cuda'): latent_image = self.vae.decode(self.vae.encode(image.to(self.device)).sample()) latent_image = (latent_image / 2 + 0.5).clamp(0, 1) latent_image = latent_image.cpu().permute(0, 2, 3, 1).numpy() if latent_image.ndim == 3: latent_image = latent_image[None, ...] latent_image = (latent_image * 255).round().astype('uint8') latent_image = [Image.fromarray(image) for image in latent_image] return latent_image ================================================ FILE: src/stablediffusion/translation.py ================================================ import inspect from typing import List, Optional, Union import numpy as np import torch import PIL from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, PNDMScheduler, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer def preprocess(image): w, h = image.size w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 image = image.resize((w, h), resample=PIL.Image.LANCZOS) image = np.array(image).astype(np.float32) / 255.0 image = image[None].transpose(0, 3, 1, 2) image = torch.from_numpy(image) return 2.0 * image - 1.0 class 
StableDiffusionImg2ImgPipeline(DiffusionPipeline): def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler], ): super().__init__() scheduler = scheduler.set_format("pt") self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler ) @torch.no_grad() def __call__( self, prompt: Union[str, List[str]], init_image: torch.FloatTensor, strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, eta: Optional[float] = 0.0, generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", ): if isinstance(prompt, str): batch_size = 1 elif isinstance(prompt, list): batch_size = len(prompt) else: raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") # set timesteps accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) extra_set_kwargs = {} offset = 0 if accepts_offset: offset = 1 extra_set_kwargs["offset"] = 1 self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) # encode the init image into latents and scale the latents init_latents = self.vae.encode(init_image.to(self.device)).sample() init_latents = 0.18215 * init_latents # prepare init_latents noise to latents init_latents = torch.cat([init_latents] * batch_size) # get the original timestep using init_timestep init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) timesteps = self.scheduler.timesteps[-init_timestep] timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device) init_latents 
= self.scheduler.add_noise(init_latents, noise, timesteps) # get prompt text embeddings text_input = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance: max_length = text_input.input_ids.shape[-1] uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" ) uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] # scale and decode the image latents with vae latents = 1 / 0.18215 * latents image = self.vae.decode(latents) image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": image = self.numpy_to_pil(image) return {"sample": image} ================================================ FILE: storage/init/.keep ================================================ ================================================ FILE: storage/outputs/.keep ================================================ ================================================ FILE: win10fix.bat ================================================ python src\scripts\win10patch.py