[
  {
    "path": ".gitignore",
    "content": "# General\n.DS_Store\n.AppleDouble\n.LSOverride\nmodels\n# models/\nlogs\nlogs/\nexps/\ndatasets/\nsrc/\n# Icon must end with two \\r\nIcon\n\n\n# Thumbnails\n._*\n\n# Files that might appear in the root of a volume\n.DocumentRevisions-V100\n.fseventsd\n.Spotlight-V100\n.TemporaryItems\n.Trashes\n.VolumeIcon.icns\n.com.apple.timemachine.donotpresent\n\n# Directories potentially created on remote AFP share\n.AppleDB\n.AppleDesktop\nNetwork Trash Folder\nTemporary Items\n.apdisk\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n*/logs/\n*/wandb/\n*samples/\n\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 Rinon Gal, Yuval Alaluf, Yuval Atzmon, Or Patashnik and contributors\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "## SINE <br><sub> <ins>SIN</ins>gle Image <ins>E</ins>diting with Text-to-Image Diffusion Models</sub>\n\n[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/zhang-zx/SINE/blob/master/SINE.ipynb)\n\n\n[Project](https://zhang-zx.github.io/SINE/) |\n[ArXiv](https://arxiv.org/abs/2212.04489) \n\n\nThis respository contains the code for the CVPR 2023 paper [SINE: SINgle Image Editing with Text-to-Image Diffusion Models](https://arxiv.org/abs/2212.04489).\nFor more visualization results, please check our [webpage](https://zhang-zx.github.io/SINE/).\n\n> **[SINE: SINgle Image Editing with Text-to-Image Diffusion Models](https://zhang-zx.github.io/SINE/)** \\\n> [Zhixing Zhang](https://zhang-zx.github.io/) <sup>1</sup>,\n> [Ligong Han](https://phymhan.github.io/) <sup>1</sup>,\n> [Arnab Ghosh](https://arnabgho.github.io/) <sup>2</sup>,\n> [Dimitris Metaxas](https://people.cs.rutgers.edu/~dnm/) <sup>1</sup>,\n> and [Jian Ren](https://alanspike.github.io/) <sup>2</sup> \\\n> <sup>1</sup> Rutgers University\n> <sup>2</sup> Snap Inc.\\\n> CVPR 2023.\n<div align=\"center\">\n    <a><img src=\"assets/overview_finetuning.png\"  width=\"500\" ></a>\n    <a><img src=\"assets/overview_editing.png\"  width=\"500\" ></a>\n</div>\n\n## Setup\n\nFirst, clone the repository and install the dependencies:\n\n```bash\ngit clone git@github.com:zhang-zx/SINE.git\n```\n\nThen, install the dependencies following the [instructions](https://github.com/CompVis/stable-diffusion#stable-diffusion-v1).\n\nAlternatively, you can also try to use the following docker image.\n\n```bash\ndocker pull sunggukcha/sine\n```\n\n\nTo fine-tune the model, you need to download the [pre-trained model](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4-full-ema.ckpt).\n\n### Data Preparation\n\nThe data we use in the paper can be found from [here](https://drive.google.com/drive/folders/1rGt5YTCwNgEag8MD_1wr9jrPpi1_8vfu?usp=sharing).\n\n## Fine-tuning\n\n### Fine-tuning w/o patch-based training scheme\n\n```bash\nIMG_PATH=path/to/image\nCLS_WRD='coarse class word'\nNAME='name of the experiment'\n\npython main.py \\\n    --base configs/stable-diffusion/v1-finetune_picture.yaml \\\n    -t --actual_resume /path/to/pre-trained/model \\\n    -n $NAME --gpus 0,  --logdir ./logs \\\n    --data_root $IMG_PATH \\\n    --reg_data_root $IMG_PATH --class_word $CLS_WRD \n```\n\n### Fine-tuning with patch-based training scheme\n\n```bash\nIMG_PATH=path/to/image\nCLS_WRD='coarse class word'\nNAME='name of the experiment'\n\npython main.py \\\n    --base configs/stable-diffusion/v1-finetune_patch_picture.yaml \\\n    -t --actual_resume /path/to/pre-trained/model \\\n    -n $NAME --gpus 0,   --logdir ./logs \\\n    --data_root $IMG_PATH \\\n    --reg_data_root $IMG_PATH --class_word $CLS_WRD  \n```\n\n## Model-based Image Editing\n\n### Editing with one model's guidance\n\n```bash\nLOG_DIR=/path/to/logdir\npython scripts/stable_txt2img_guidance.py --ddim_eta 0.0 --n_iter 1 \\\n    --scale 10 --ddim_steps 100 \\\n    --sin_config configs/stable-diffusion/v1-inference.yaml \\\n    --sin_ckpt $LOG_DIR\"/checkpoints/last.ckpt\" \\\n    --prompt \"prompt for pre-trained model[SEP]prompt for fine-tuned model\" \\\n    --cond_beta 0.4 \\\n    --range_t_min 500 --range_t_max 1000 --single_guidance \\\n    --skip_save --H 512 --W 512 --n_samples 2 \\\n    --outdir $LOG_DIR\n```\n\n### Editing with multiple models' guidance\n\n```bash\npython 
\n\n### Data Preparation\n\nThe data we use in the paper can be found [here](https://drive.google.com/drive/folders/1rGt5YTCwNgEag8MD_1wr9jrPpi1_8vfu?usp=sharing).\n\n## Fine-tuning\n\n### Fine-tuning w/o patch-based training scheme\n\n```bash\nIMG_PATH=path/to/image\nCLS_WRD='coarse class word'\nNAME='name of the experiment'\n\npython main.py \\\n    --base configs/stable-diffusion/v1-finetune_picture.yaml \\\n    -t --actual_resume /path/to/pre-trained/model \\\n    -n $NAME --gpus 0, --logdir ./logs \\\n    --data_root $IMG_PATH \\\n    --reg_data_root $IMG_PATH --class_word $CLS_WRD\n```\n\n### Fine-tuning with patch-based training scheme\n\n```bash\nIMG_PATH=path/to/image\nCLS_WRD='coarse class word'\nNAME='name of the experiment'\n\npython main.py \\\n    --base configs/stable-diffusion/v1-finetune_patch_picture.yaml \\\n    -t --actual_resume /path/to/pre-trained/model \\\n    -n $NAME --gpus 0, --logdir ./logs \\\n    --data_root $IMG_PATH \\\n    --reg_data_root $IMG_PATH --class_word $CLS_WRD\n```\n\n## Model-based Image Editing\n\n### Editing with one model's guidance\n\n```bash\nLOG_DIR=/path/to/logdir\npython scripts/stable_txt2img_guidance.py --ddim_eta 0.0 --n_iter 1 \\\n    --scale 10 --ddim_steps 100 \\\n    --sin_config configs/stable-diffusion/v1-inference.yaml \\\n    --sin_ckpt $LOG_DIR\"/checkpoints/last.ckpt\" \\\n    --prompt \"prompt for pre-trained model[SEP]prompt for fine-tuned model\" \\\n    --cond_beta 0.4 \\\n    --range_t_min 500 --range_t_max 1000 --single_guidance \\\n    --skip_save --H 512 --W 512 --n_samples 2 \\\n    --outdir $LOG_DIR\n```\n\nThe `[SEP]` token splits the prompt: the text before it is given to the pre-trained model and the text after it to the fine-tuned model. For example, for a model fine-tuned on a dog photo with \"picture of a sks dog\", an edit could use `--prompt \"a dog wearing a superhero cape[SEP]picture of a sks dog\"`.\n\n### Editing with multiple models' guidance\n\n```bash\npython scripts/stable_txt2img_multi_guidance.py --ddim_eta 0.0 --n_iter 2 \\\n    --scale 10 --ddim_steps 100 \\\n    --sin_ckpt path/to/ckpt1 path/to/ckpt2 \\\n    --sin_config configs/stable-diffusion/v1-inference.yaml \\\n    configs/stable-diffusion/v1-inference.yaml \\\n    --prompt \"prompt for pre-trained model[SEP]prompt for fine-tuned model1[SEP]prompt for fine-tuned model2\" \\\n    --beta 0.4 0.5 \\\n    --range_t_min 400 400 --range_t_max 1000 1000 --single_guidance \\\n    --H 512 --W 512 --n_samples 2 \\\n    --outdir path/to/output_dir\n```\n\n## Diffusers Library Example\n\nSupport for the Diffusers library is still under development.\nThe results in our paper were obtained with the earlier LDM-based code.\n\n### Training\n\n```bash\nexport MODEL_NAME=\"CompVis/stable-diffusion-v1-4\"\nexport IMG_PATH=\"path/to/image\"\nexport OUTPUT_DIR=\"path/to/output_dir\"\n\naccelerate launch diffusers_train.py \\\n  --pretrained_model_name_or_path=$MODEL_NAME \\\n  --train_text_encoder \\\n  --img_path=$IMG_PATH \\\n  --output_dir=$OUTPUT_DIR \\\n  --instance_prompt=\"prompt for fine-tuning\" \\\n  --resolution=512 \\\n  --train_batch_size=1 \\\n  --gradient_accumulation_steps=1 \\\n  --learning_rate=1e-6 \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=NUMBER_OF_STEPS \\\n  --checkpointing_steps=FREQUENCY_FOR_CHECKPOINTING \\\n  --patch_based_training # OPTIONAL: add this flag for the patch-based training scheme\n```\n\n### Sampling\n\n```bash\npython diffusers_sample.py \\\n--pretrained_model_name_or_path \"path/to/output_dir\" \\\n--prompt \"prompt for fine-tuned model\" \\\n--editing_prompt \"prompt for pre-trained model\"\n```\n\n## Visualization Results\n\nSome of the editing results are shown below.\nSee more results on our [webpage](https://zhang-zx.github.io/SINE/).\n\n![image](assets/editing.png)\n\n## Acknowledgments\n\nThis code builds on the following implementations: [Dreambooth-Stable-Diffusion](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion) and [stable-diffusion](https://github.com/CompVis/stable-diffusion#stable-diffusion-v1).\nThe Diffusers-based implementation draws heavily on the [Dreambooth example](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth).\nMany thanks to them!\n\n## Reference\n\nIf our work or code helps you, please consider citing our paper. Thank you!\n\n```BibTeX\n@article{zhang2022sine,\n  title={SINE: SINgle Image Editing with Text-to-Image Diffusion Models},\n  author={Zhang, Zhixing and Han, Ligong and Ghosh, Arnab and Metaxas, Dimitris and Ren, Jian},\n  journal={arXiv preprint arXiv:2212.04489},\n  year={2022}\n}\n```\n"
  },
  {
    "path": "SINE.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Welcome to SINE: SINgle Image Editing with Text-to-Image Diffusion Models!\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"gpu_info = !nvidia-smi\\n\",\n    \"gpu_info = '\\\\n'.join(gpu_info)\\n\",\n    \"if gpu_info.find('failed') >= 0:\\n\",\n    \"    print('Not connected to a GPU')\\n\",\n    \"else:\\n\",\n    \"    print(gpu_info)\\n\",\n    \"\\n\",\n    \"from psutil import virtual_memory\\n\",\n    \"ram_gb = virtual_memory().total / 1e9\\n\",\n    \"print('Your runtime has {:.1f} gigabytes of available RAM\\\\n'.format(ram_gb))\\n\",\n    \"\\n\",\n    \"if ram_gb < 20:\\n\",\n    \"    print('Not using a high-RAM runtime')\\n\",\n    \"else:\\n\",\n    \"    print('You are using a high-RAM runtime!')\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Step 1: Setup required libraries and models. \\n\",\n    \"This may take a few minutes.\\n\",\n    \"\\n\",\n    \"You may optionally enable downloads with pydrive in order to authenticate and avoid drive download limits when fetching the pre-trained model.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"#@title Setup\\n\",\n    \"\\n\",\n    \"import os\\n\",\n    \"\\n\",\n    \"from pydrive.auth import GoogleAuth\\n\",\n    \"from pydrive.drive import GoogleDrive\\n\",\n    \"from google.colab import auth\\n\",\n    \"from oauth2client.client import GoogleCredentials\\n\",\n    \"\\n\",\n    \"from argparse import Namespace\\n\",\n    \"\\n\",\n    \"import sys\\n\",\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"from PIL import Image\\n\",\n    \"\\n\",\n    \"import torch\\n\",\n    \"import torchvision.transforms as transforms\\n\",\n    \"\\n\",\n    \"device = 'cuda'\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# install requirements\\n\",\n    \"!git clone https://github.com/zhang-zx/SINE.git sine_dir\\n\",\n    \"\\n\",\n    \"%cd sine_dir/\\n\",\n    \"!pip uninstall -y torchtext\\n\",\n    \"! pip install transformers==4.18.0 einops==0.4.1 omegaconf==2.1.1 torchmetrics==0.6.0 torch-fidelity==0.3.0 kornia==0.6 albumentations==1.1.0 opencv-python==4.2.0.34 imageio==2.14.1 setuptools==59.5.0 pillow==9.0.1 \\n\",\n    \"! pip install torch==1.10.2 torchvision==0.11.3\\n\",\n    \"! pip install pytorch-lightning==1.5.9\\n\",\n    \"! pip install git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers\\n\",\n    \"! pip install git+https://github.com/openai/CLIP.git@main#egg=clip\\n\",\n    \"! 
pip install -e .\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"download_with_pydrive = True \\n\",\n    \"    \\n\",\n    \"class Downloader(object):\\n\",\n    \"    def __init__(self, use_pydrive):\\n\",\n    \"        self.use_pydrive = use_pydrive\\n\",\n    \"\\n\",\n    \"        if self.use_pydrive:\\n\",\n    \"            self.authenticate()\\n\",\n    \"        \\n\",\n    \"    def authenticate(self):\\n\",\n    \"        auth.authenticate_user()\\n\",\n    \"        gauth = GoogleAuth()\\n\",\n    \"        gauth.credentials = GoogleCredentials.get_application_default()\\n\",\n    \"        self.drive = GoogleDrive(gauth)\\n\",\n    \"    \\n\",\n    \"    def download_file(self, file_id, file_dst):\\n\",\n    \"        if self.use_pydrive:\\n\",\n    \"            downloaded = self.drive.CreateFile({'id':file_id})\\n\",\n    \"            downloaded.FetchMetadata(fetch_all=True)\\n\",\n    \"            downloaded.GetContentFile(file_dst)\\n\",\n    \"        else:\\n\",\n    \"            !gdown --id $file_id -O $file_dst\\n\",\n    \"\\n\",\n    \"downloader = Downloader(download_with_pydrive)\\n\",\n    \"\\n\",\n    \"pre_trained_path = os.path.join('models', 'ldm', 'stable-diffusion-v4')\\n\",\n    \"os.makedirs(pre_trained_path, exist_ok=True)\\n\",\n    \"!wget -O models/ldm/stable-diffusion-v4/sd-v1-4-full-ema.ckpt https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4-full-ema.ckpt \\n\",\n    \"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Step 2: Download the selected fine-tuned model.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"finetuned_models_dir = os.path.join('./models', 'finetuned')\\n\",\n    \"os.makedirs(finetuned_models_dir, exist_ok=True)\\n\",\n    \"\\n\",\n    \"orig_image_dir = './dataset'\\n\",\n    \"os.makedirs(orig_image_dir, exist_ok=True)\\n\",\n    \"\\n\",\n    \"source_model_type = 'dog w/o patch-based fine-tuning' #@param['dog w/o patch-based fine-tuning', 'dog w/ patch-based fine-tuning', 'Girl with a pearl earring', 'Monalisa', 'castle w/o patch-based fine-tuning', 'castle w/ patch-based fine-tuning']\\n\",\n    \"source_model_download_path = {\\\"dog w/o patch-based fine-tuning\\\":   \\\"1jHgkyxrwUXyMR2zBK9WAEWioEP3-F3fd\\\",\\n\",\n    \"                              \\\"dog w/ patch-based fine-tuning\\\":    \\\"1YI7c29qBIy83OqJ4ykoAAul6P8uaXmls\\\",\\n\",\n    \"                              \\\"Girl with a pearl earring\\\":    \\\"1l6GCEfyURKQiCF77ZriYoZRtkOXQCyWD\\\",\\n\",\n    \"                              \\\"Monalisa\\\": \\\"194CDgHkomKrLvgFj89kamoTjUbMwyUIC\\\",\\n\",\n    \"                              \\\"castle w/o patch-based fine-tuning\\\":    \\\"19I8ftab9vMQWnqPH2O7aHe-GnYolmVFF\\\",\\n\",\n    \"                              \\\"castle w/ patch-based fine-tuning\\\":  \\\"1srzUr1fg6jTFKuf0M5oi5JgBsVhCt-nb\\\"}\\n\",\n    \"\\n\",\n    \"model_names = { \\\"dog w/o patch-based fine-tuning\\\":   \\\"dog_wo_patch.ckpt\\\",\\n\",\n    \"                \\\"dog w/ patch-based fine-tuning\\\":    \\\"dog_w_patch.ckpt\\\",\\n\",\n    \"                \\\"Girl with a pearl earring\\\":    \\\"girl.ckpt\\\",\\n\",\n    \"                \\\"Monalisa\\\": \\\"monalisa.ckpt\\\",\\n\",\n    \"                \\\"castle w/o patch-based fine-tuning\\\":    \\\"castle_wo_patch\\\",\\n\",\n    \"                \\\"castle w/ patch-based fine-tuning\\\":  \\\"castle_w_patch\\\"}\\n\",
\n    \"\\n\",\n    \"model_configs = { \\\"dog w/o patch-based fine-tuning\\\":   \\\"./configs/stable-diffusion/v1-inference.yaml\\\",\\n\",\n    \"                \\\"dog w/ patch-based fine-tuning\\\":    \\\"./configs/stable-diffusion/v1-inference_patch.yaml\\\",\\n\",\n    \"                \\\"Girl with a pearl earring\\\":    \\\"./configs/stable-diffusion/v1-inference_patch_nearest_interp.yaml\\\",\\n\",\n    \"                \\\"Monalisa\\\": \\\"./configs/stable-diffusion/v1-inference_patch_nearest_interp.yaml\\\",\\n\",\n    \"                \\\"castle w/o patch-based fine-tuning\\\":    \\\"./configs/stable-diffusion/v1-inference.yaml\\\",\\n\",\n    \"                \\\"castle w/ patch-based fine-tuning\\\":  \\\"./configs/stable-diffusion/v1-inference_patch.yaml\\\"}\\n\",\n    \"\\n\",\n    \"orig_prompts = { \\\"dog w/o patch-based fine-tuning\\\":   \\\"picture of a sks dog\\\",\\n\",\n    \"                \\\"dog w/ patch-based fine-tuning\\\":    \\\"picture of a sks dog\\\",\\n\",\n    \"                \\\"Girl with a pearl earring\\\":    \\\"painting of a sks girl\\\",\\n\",\n    \"                \\\"Monalisa\\\": \\\"painting of a sks lady\\\",\\n\",\n    \"                \\\"castle w/o patch-based fine-tuning\\\":    \\\"picture of a sks castle\\\",\\n\",\n    \"                \\\"castle w/ patch-based fine-tuning\\\":  \\\"picture of a sks castle\\\"}\\n\",\n    \"\\n\",\n    \"download_string = source_model_download_path[source_model_type]\\n\",\n    \"file_name = model_names[source_model_type]\\n\",\n    \"\\n\",\n    \"config_name = model_configs[source_model_type]\\n\",\n    \"fine_tune_prompt = orig_prompts[source_model_type]\\n\",\n    \"\\n\",\n    \"if not os.path.isfile(os.path.join(finetuned_models_dir, file_name)):\\n\",\n    \"    downloader.download_file(download_string, os.path.join(finetuned_models_dir, file_name))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# Step 3: Edit the image with model-based guidance\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import argparse, os, sys, glob\\n\",\n    \"import torch\\n\",\n    \"import numpy as np\\n\",\n    \"from omegaconf import OmegaConf\\n\",\n    \"from PIL import Image\\n\",\n    \"from tqdm import tqdm, trange\\n\",\n    \"from itertools import islice\\n\",\n    \"from einops import rearrange\\n\",\n    \"from torchvision.utils import make_grid, save_image\\n\",\n    \"import time\\n\",\n    \"from pytorch_lightning import seed_everything\\n\",\n    \"from torch import autocast\\n\",\n    \"from contextlib import contextmanager, nullcontext\\n\",\n    \"\\n\",\n    \"from ldm.util import instantiate_from_config\\n\",\n    \"from ldm.models.diffusion.ddim import DDIMSampler\\n\",\n    \"from ldm.models.diffusion.plms import PLMSSampler\\n\",\n    \"from IPython.display import display\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def chunk(it, size):\\n\",\n    \"    it = iter(it)\\n\",\n    \"    return iter(lambda: tuple(islice(it, size)), ())\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def load_model_from_config(config, ckpt, verbose=False):\\n\",\n    \"    print(f\\\"Loading model from {ckpt}\\\")\\n\",\n    \"    pl_sd = torch.load(ckpt, map_location=\\\"cpu\\\")\\n\",\n    \"    if \\\"global_step\\\" in pl_sd:\\n\",\n    \"        print(f\\\"Global Step: {pl_sd['global_step']}\\\")\\n\",
    \"    sd = pl_sd[\\\"state_dict\\\"]\\n\",\n    \"    model = instantiate_from_config(config.model)\\n\",\n    \"    m, u = model.load_state_dict(sd, strict=False)\\n\",\n    \"    if len(m) > 0 and verbose:\\n\",\n    \"        print(\\\"missing keys:\\\")\\n\",\n    \"        print(m)\\n\",\n    \"    if len(u) > 0 and verbose:\\n\",\n    \"        print(\\\"unexpected keys:\\\")\\n\",\n    \"        print(u)\\n\",\n    \"\\n\",\n    \"    model.cuda()\\n\",\n    \"    model.eval()\\n\",\n    \"    return model\\n\",\n    \"\\n\",\n    \"seed = 42\\n\",\n    \"config = OmegaConf.load('configs/stable-diffusion/v1-inference.yaml')\\n\",\n    \"model = load_model_from_config(config, 'models/ldm/stable-diffusion-v4/sd-v1-4-full-ema.ckpt')\\n\",\n    \"\\n\",\n    \"device = torch.device(\\\"cuda\\\") if torch.cuda.is_available() else torch.device(\\\"cpu\\\")\\n\",\n    \"model = model.to(device)\\n\",\n    \"\\n\",\n    \"sin_config = OmegaConf.load(config_name)\\n\",\n    \"sin_model = load_model_from_config(sin_config, os.path.join(finetuned_models_dir, file_name))\\n\",\n    \"sin_model = sin_model.to(device)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"v = 0.7 #@param {type:\\\"slider\\\", min:0, max:1, step:0.05}\\n\",\n    \"K_min = 400 #@param {type:\\\"slider\\\", min:0, max:1000, step:10}\\n\",\n    \"scale = 7.5 #@param {type:\\\"slider\\\", min:1.0, max:50, step:0.5}\\n\",\n    \"ddim_steps = 100\\n\",\n    \"ddim_eta = 0.\\n\",\n    \"H = 512\\n\",\n    \"W = 512\\n\",\n    \"\\n\",\n    \"prompt = \\\"a dog wearing a superhero cape\\\" #@param {'type': 'string'}\\n\",\n    \"\\n\",\n    \"# v sets cond_beta and 1 - v sets cond_beta_sin (the weight on the fine-tuned model);\\n\",\n    \"# guidance is active for timesteps in [K_min, 1000].\\n\",\n    \"extra_config = {\\n\",\n    \"    'cond_beta': v,\\n\",\n    \"    'cond_beta_sin': 1.
- v,\\n\",\n    \"    'range_t_max': 1000,\\n\",\n    \"    'range_t_min': K_min\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"from ldm.models.diffusion.guidance_ddim import DDIMSinSampler\\n\",\n    \"sampler = DDIMSinSampler(model, sin_model)\\n\",\n    \"\\n\",\n    \"setattr(sampler.model, 'extra_config', extra_config)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"batch_size = 1\\n\",\n    \"n_rows = 2\\n\",\n    \"start_code = None\\n\",\n    \"precision_scope = autocast\\n\",\n    \"num_samples = 4\\n\",\n    \"\\n\",\n    \"all_samples = list()\\n\",\n    \"\\n\",\n    \"with torch.no_grad():\\n\",\n    \"    with precision_scope(\\\"cuda\\\"):\\n\",\n    \"        with model.ema_scope():\\n\",\n    \"            with sin_model.ema_scope():\\n\",\n    \"                tic = time.time()\\n\",\n    \"                all_samples = list()\\n\",\n    \"                for n in trange(num_samples, desc=\\\"Sampling\\\"):   \\n\",\n    \"                    uc = None\\n\",\n    \"                    if scale != 1.0:\\n\",\n    \"                        uc = model.get_learned_conditioning(batch_size * [\\\"\\\"])\\n\",\n    \"                        uc_sin = sin_model.get_learned_conditioning(batch_size * [\\\"\\\"])\\n\",\n    \"\\n\",\n    \"                    prompts = [prompt] * batch_size\\n\",\n    \"                    prompts_single = [fine_tune_prompt] * batch_size\\n\",\n    \"                    \\n\",\n    \"                    c = model.get_learned_conditioning(prompts)\\n\",\n    \"                    c_sin = sin_model.get_learned_conditioning(prompts_single)\\n\",\n    \"                    \\n\",\n    \"                    shape = [4, H // 8, W // 8]\\n\",\n    \"                    samples_ddim, _ = sampler.sample( S=ddim_steps,\\n\",\n    \"                                                      conditioning=c,\\n\",\n    \"                                                      conditioning_single=c_sin,\\n\",\n    \"                                                      batch_size=batch_size,\\n\",\n    \"                                                      shape=shape,\\n\",\n    \"                                                      verbose=False,\\n\",\n    \"                                                      unconditional_guidance_scale=scale,\\n\",\n    \"                                                      unconditional_conditioning=uc,\\n\",\n    \"                                                      unconditional_conditioning_single=uc_sin,\\n\",\n    \"                                                      eta=ddim_eta,\\n\",\n    \"                                                      x_T=start_code)\\n\",\n    \"\\n\",\n    \"                    x_samples_ddim = model.decode_first_stage(samples_ddim)\\n\",\n    \"                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\\n\",\n    \"\\n\",\n    \"                    all_samples.append(x_samples_ddim)\\n\",\n    \"\\n\",\n    \"                grid = torch.stack(all_samples, 0)\\n\",\n    \"                grid = rearrange(grid, 'n b c h w -> (n b) c h w')\\n\",\n    \"                \\n\",\n    \"                grid = make_grid(grid, nrow=n_rows)\\n\",\n    \"\\n\",\n    \"                # to image\\n\",\n    \"                grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy()\\n\",\n    \"                os.makedirs('./output', exist_ok=True)\\n\",\n    \"                Image.fromarray(grid.astype(np.uint8)).save(os.path.join('./output', f'{prompt.replace(\\\" \\\", \\\"-\\\")}.jpg'))\\n\",\n    \"                display(Image.open(os.path.join('./output', f'{prompt.replace(\\\" \\\", \\\"-\\\")}.jpg')))\\n\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3.8.0 ('SINE')\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.0\"\n  },\n  \"orig_nbformat\": 4,\n  \"vscode\": {\n   \"interpreter\": {\n    \"hash\": \"a84f578b9cda1db545aa6690161d7775d6ea32a647f25bb9ef4866c136688289\"\n   }\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "configs/stable-diffusion/v1-finetune_painting.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-06\n  target: ldm.models.diffusion.ddpm.LatentDiffusion\n  params:\n    reg_weight: 0.0\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: image\n    cond_stage_key: caption\n    image_size: 64\n    channels: 4\n    cond_stage_trainable: true   # Note: different from the one we trained before\n    conditioning_key: crossattn\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n    use_ema: False\n    embedding_reg_weight: 0.0\n    unfreeze_model: True\n    model_lr: 1.0e-06\n\n    personalization_config:\n      target: ldm.modules.embedding_manager.EmbeddingManager\n      params:\n        placeholder_strings: [\"*\"]\n        initializer_words: [\"sculpture\"]\n        per_image_tokens: false\n        num_vectors_per_token: 1\n        progressive_words: False\n\n    unet_config:\n      target: ldm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        image_size: 32 # unused\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_heads: 8\n        use_spatial_transformer: True\n        transformer_depth: 1\n        context_dim: 768\n        use_checkpoint: True\n        legacy: False\n\n    first_stage_config:\n      target: ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 512\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    cond_stage_config:\n      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder\n\ndata:\n  target: main.DataModuleFromConfig\n  params:\n    batch_size: 1\n    num_workers: 2\n    wrap: false\n    train:\n      target: ldm.data.personalized_painting.SinImageDataset\n      params:\n        size: 512\n        set: train\n        per_image_tokens: false\n        repeats: 100\n        flip_p: 0.0\n    reg:\n      target: ldm.data.personalized_painting.SinImageDataset\n      params:\n        size: 512\n        set: train\n        per_image_tokens: false\n        repeats: 100\n        flip_p: 0.0\n        \n    validation:\n      target: ldm.data.personalized_painting.SinImageDataset\n      params:\n        size: 512\n        set: val\n        per_image_tokens: false\n        repeats: 10\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 500\n  callbacks:\n    image_logger:\n      target: main.ImageLogger\n      params:\n        batch_frequency: 500\n        max_images: 8\n        increase_log_steps: False\n\n  trainer:\n    benchmark: True\n    max_steps: 800\n"
  },
  {
    "path": "configs/stable-diffusion/v1-finetune_painting_style.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-06\n  target: ldm.models.diffusion.ddpm.LatentDiffusion\n  params:\n    reg_weight: 0.0\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: image\n    cond_stage_key: caption\n    image_size: 64\n    channels: 4\n    cond_stage_trainable: true   # Note: different from the one we trained before\n    conditioning_key: crossattn\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n    use_ema: False\n    embedding_reg_weight: 0.0\n    unfreeze_model: True\n    model_lr: 1.0e-06\n\n    personalization_config:\n      target: ldm.modules.embedding_manager.EmbeddingManager\n      params:\n        placeholder_strings: [\"*\"]\n        initializer_words: [\"sculpture\"]\n        per_image_tokens: false\n        num_vectors_per_token: 1\n        progressive_words: False\n\n    unet_config:\n      target: ldm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        image_size: 32 # unused\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_heads: 8\n        use_spatial_transformer: True\n        transformer_depth: 1\n        context_dim: 768\n        use_checkpoint: True\n        legacy: False\n\n    first_stage_config:\n      target: ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 512\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    cond_stage_config:\n      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder\n\ndata:\n  target: main.DataModuleFromConfig\n  params:\n    batch_size: 1\n    num_workers: 2\n    wrap: false\n    train:\n      target: ldm.data.personalized_painting.SinImageDataset\n      params:\n        size: 512\n        set: train\n        per_image_tokens: false\n        repeats: 100\n        flip_p: 0.0\n        learn_style: true\n    reg:\n      target: ldm.data.personalized_painting.SinImageDataset\n      params:\n        size: 512\n        set: train\n        per_image_tokens: false\n        repeats: 100\n        flip_p: 0.0\n        \n    validation:\n      target: ldm.data.personalized_painting.SinImageDataset\n      params:\n        size: 512\n        set: val\n        per_image_tokens: false\n        repeats: 10\n        learn_style: true\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 500\n  callbacks:\n    image_logger:\n      target: main.ImageLogger\n      params:\n        batch_frequency: 500\n        max_images: 8\n        increase_log_steps: False\n\n  trainer:\n    benchmark: True\n    max_steps: 800\n"
  },
  {
    "path": "configs/stable-diffusion/v1-finetune_patch_painting.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-06\n  target: ldm.models.diffusion.ddpm.LatentDiffusion\n  params:\n    reg_weight: 0.0\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: image\n    cond_stage_key: caption\n    image_size: 64\n    channels: 4\n    cond_stage_trainable: true   # Note: different from the one we trained before\n    conditioning_key: crossattn\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n    use_ema: False\n    embedding_reg_weight: 0.0\n    unfreeze_model: True\n    model_lr: 1.0e-6\n    scale_recon_loss: 1.0\n\n    personalization_config:\n      target: ldm.modules.embedding_manager.EmbeddingManager\n      params:\n        placeholder_strings: [\"*\"]\n        initializer_words: [\"sculpture\"]\n        per_image_tokens: false\n        num_vectors_per_token: 1\n        progressive_words: False\n\n    unet_config:\n      target: ldm.modules.diffusionmodules.openaimodel.UNetModelPatch\n      params:\n        image_size: 32 # unused\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_heads: 8\n        use_spatial_transformer: True\n        transformer_depth: 1\n        context_dim: 768\n        use_checkpoint: True\n        legacy: False\n        padding_idx: 0\n        init_size: 128 \n        div_half_dim: false\n        center_shift: 100\n        interpolation_mode: \"nearest\"\n\n    first_stage_config:\n      target: ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 512\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    cond_stage_config:\n      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder\n\ndata:\n  target: main.DataModuleFromConfig\n  params:\n    batch_size: 1\n    num_workers: 2\n    wrap: false\n    train:\n      target: ldm.data.personalized_painting.SinImageHighResDataset\n      params:\n        size: 512\n        high_resolution: 1024\n        set: train\n        per_image_tokens: false\n        repeats: 100\n        min_crop_frac: 0.1\n        max_crop_frac: 1.0\n        rec_prob: 1.\n        latent_scale: 8\n    reg:\n      target: ldm.data.personalized_painting.SinImageHighResDataset\n      params:\n        size: 512\n        high_resolution: 1024\n        set: train\n        per_image_tokens: false\n        repeats: 1\n        min_crop_frac: 0.1\n        max_crop_frac: 1.0\n        rec_prob: 0.\n        latent_scale: 8\n        \n    validation:\n      target: ldm.data.personalized_painting.SinImageHighResDataset\n      params:\n        size: 512\n        high_resolution: 1024\n        set: val\n        per_image_tokens: false\n        repeats: 10\n        min_crop_frac: 0.2\n        max_crop_frac: 1.0\n        rec_prob: 0.25\n        latent_scale: 8\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 2000\n  callbacks:\n    image_logger:\n      target: main.ImageLogger\n      params:\n        batch_frequency: 2000\n        max_images: 8\n        increase_log_steps: 
False\n\n  trainer:\n    benchmark: True\n    max_steps: 10000\n"
  },
  {
    "path": "configs/stable-diffusion/v1-finetune_patch_picture.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-06\n  target: ldm.models.diffusion.ddpm.LatentDiffusion\n  params:\n    reg_weight: 0.0\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: image\n    cond_stage_key: caption\n    image_size: 64\n    channels: 4\n    cond_stage_trainable: true   # Note: different from the one we trained before\n    conditioning_key: crossattn\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n    use_ema: False\n    embedding_reg_weight: 0.0\n    unfreeze_model: True\n    model_lr: 1.0e-6\n    scale_recon_loss: 1.0\n\n    personalization_config:\n      target: ldm.modules.embedding_manager.EmbeddingManager\n      params:\n        placeholder_strings: [\"*\"]\n        initializer_words: [\"sculpture\"]\n        per_image_tokens: false\n        num_vectors_per_token: 1\n        progressive_words: False\n\n    unet_config:\n      target: ldm.modules.diffusionmodules.openaimodel.UNetModelPatch\n      params:\n        image_size: 32 # unused\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_heads: 8\n        use_spatial_transformer: True\n        transformer_depth: 1\n        context_dim: 768\n        use_checkpoint: True\n        legacy: False\n        padding_idx: 0\n        init_size: 128 \n        div_half_dim: false\n        center_shift: 100\n        interpolation_mode: \"bilinear\"\n\n    first_stage_config:\n      target: ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 512\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    cond_stage_config:\n      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder\n\ndata:\n  target: main.DataModuleFromConfig\n  params:\n    batch_size: 1\n    num_workers: 2\n    wrap: false\n    train:\n      target: ldm.data.personalized.SinImageHighResDataset\n      params:\n        size: 512\n        high_resolution: 1024\n        set: train\n        per_image_tokens: false\n        repeats: 100\n        min_crop_frac: 0.1\n        max_crop_frac: 1.0\n        rec_prob: 1.\n        latent_scale: 8\n    reg:\n      target: ldm.data.personalized.SinImageHighResDataset\n      params:\n        size: 512\n        high_resolution: 1024\n        set: train\n        per_image_tokens: false\n        repeats: 1\n        min_crop_frac: 0.1\n        max_crop_frac: 1.0\n        rec_prob: 0.\n        latent_scale: 8\n        \n    validation:\n      target: ldm.data.personalized.SinImageHighResDataset\n      params:\n        size: 512\n        high_resolution: 1024\n        set: val\n        per_image_tokens: false\n        repeats: 10\n        min_crop_frac: 0.2\n        max_crop_frac: 1.0\n        rec_prob: 0.25\n        latent_scale: 8\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 7000\n  callbacks:\n    image_logger:\n      target: main.ImageLogger\n      params:\n        batch_frequency: 2000\n        max_images: 8\n        increase_log_steps: False\n\n  trainer:\n    
benchmark: True\n    max_steps: 10000\n"
  },
  {
    "path": "configs/stable-diffusion/v1-finetune_picture.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-06\n  target: ldm.models.diffusion.ddpm.LatentDiffusion\n  params:\n    reg_weight: 0.0\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: image\n    cond_stage_key: caption\n    image_size: 64\n    channels: 4\n    cond_stage_trainable: true   # Note: different from the one we trained before\n    conditioning_key: crossattn\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n    use_ema: False\n    embedding_reg_weight: 0.0\n    unfreeze_model: True\n    model_lr: 1.0e-06\n\n    personalization_config:\n      target: ldm.modules.embedding_manager.EmbeddingManager\n      params:\n        placeholder_strings: [\"*\"]\n        initializer_words: [\"sculpture\"]\n        per_image_tokens: false\n        num_vectors_per_token: 1\n        progressive_words: False\n\n    unet_config:\n      target: ldm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        image_size: 32 # unused\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_heads: 8\n        use_spatial_transformer: True\n        transformer_depth: 1\n        context_dim: 768\n        use_checkpoint: True\n        legacy: False\n\n    first_stage_config:\n      target: ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 512\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    cond_stage_config:\n      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder\n\ndata:\n  target: main.DataModuleFromConfig\n  params:\n    batch_size: 1\n    num_workers: 2\n    wrap: false\n    train:\n      target: ldm.data.personalized.SinImageDataset\n      params:\n        size: 512\n        set: train\n        per_image_tokens: false\n        repeats: 100\n        flip_p: 0.\n    reg:\n      target: ldm.data.personalized.SinImageDataset\n      params:\n        size: 512\n        set: train\n        per_image_tokens: false\n        repeats: 100\n        flip_p: 0.\n        \n    validation:\n      target: ldm.data.personalized.SinImageDataset\n      params:\n        size: 512\n        set: val\n        per_image_tokens: false\n        repeats: 10\n\nlightning:\n  modelcheckpoint:\n    params:\n      every_n_train_steps: 800\n  callbacks:\n    image_logger:\n      target: main.ImageLogger\n      params:\n        batch_frequency: 500\n        max_images: 8\n        increase_log_steps: False\n\n  trainer:\n    benchmark: True\n    max_steps: 800\n"
  },
  {
    "path": "configs/stable-diffusion/v1-inference.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-04\n  target: ldm.models.diffusion.ddpm.LatentDiffusion\n  params:\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: \"jpg\"\n    cond_stage_key: \"txt\"\n    image_size: 64\n    channels: 4\n    cond_stage_trainable: false   # Note: different from the one we trained before\n    conditioning_key: crossattn\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n    use_ema: False\n\n    personalization_config:\n      target: ldm.modules.embedding_manager.EmbeddingManager\n      params:\n        placeholder_strings: [\"*\"]\n        initializer_words: [\"sculpture\"]\n        per_image_tokens: false\n        num_vectors_per_token: 1\n        progressive_words: False\n        \n    unet_config:\n      target: ldm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        image_size: 32 # unused\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_heads: 8\n        use_spatial_transformer: True\n        transformer_depth: 1\n        context_dim: 768\n        use_checkpoint: True\n        legacy: False\n\n    first_stage_config:\n      target: ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    cond_stage_config:\n      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder"
  },
  {
    "path": "configs/stable-diffusion/v1-inference_patch.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-04\n  target: ldm.models.diffusion.ddpm.LatentDiffusion\n  params:\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: \"jpg\"\n    cond_stage_key: \"txt\"\n    image_size: 64\n    channels: 4\n    cond_stage_trainable: false   # Note: different from the one we trained before\n    conditioning_key: crossattn\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n    use_ema: False\n\n    personalization_config:\n      target: ldm.modules.embedding_manager.EmbeddingManager\n      params:\n        placeholder_strings: [\"*\"]\n        initializer_words: [\"sculpture\"]\n        per_image_tokens: false\n        num_vectors_per_token: 1\n        progressive_words: False\n        \n    unet_config:\n      target: ldm.modules.diffusionmodules.openaimodel.UNetModelPatch\n      params:\n        image_size: 32 # unused\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_heads: 8\n        use_spatial_transformer: True\n        transformer_depth: 1\n        context_dim: 768\n        use_checkpoint: True\n        legacy: False\n        padding_idx: 0\n        init_size: 128 ## Note: Might be some problem in this line. Might need to be 1024\n        div_half_dim: false\n        center_shift: 100\n        interpolation_mode: \"bilinear\" # bilinear or nearest supported\n\n    first_stage_config:\n      target: ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    cond_stage_config:\n      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder"
  },
  {
    "path": "configs/stable-diffusion/v1-inference_patch_nearest_interp.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-04\n  target: ldm.models.diffusion.ddpm.LatentDiffusion\n  params:\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: \"jpg\"\n    cond_stage_key: \"txt\"\n    image_size: 64\n    channels: 4\n    cond_stage_trainable: false   # Note: different from the one we trained before\n    conditioning_key: crossattn\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n    use_ema: False\n\n    personalization_config:\n      target: ldm.modules.embedding_manager.EmbeddingManager\n      params:\n        placeholder_strings: [\"*\"]\n        initializer_words: [\"sculpture\"]\n        per_image_tokens: false\n        num_vectors_per_token: 1\n        progressive_words: False\n        \n    unet_config:\n      target: ldm.modules.diffusionmodules.openaimodel.UNetModelPatch\n      params:\n        image_size: 32 # unused\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_heads: 8\n        use_spatial_transformer: True\n        transformer_depth: 1\n        context_dim: 768\n        use_checkpoint: True\n        legacy: False\n        padding_idx: 0\n        init_size: 128 ## Note: Might be some problem in this line. Might need to be 1024\n        div_half_dim: false\n        center_shift: 100\n        interpolation_mode: \"nearest\" # bilinear or nearest supported\n\n    first_stage_config:\n      target: ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    cond_stage_config:\n      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder"
  },
  {
    "path": "diffusers_models.py",
    "content": "from diffusers import UNet2DConditionModel\nfrom diffusers.models.unet_2d_condition import UNet2DConditionOutput\nfrom diffusers.configuration_utils import register_to_config\n\nfrom typing import Any, Dict, List, Optional, Tuple, Union\nfrom pydantic import StrictInt, StrictFloat, StrictBool, StrictStr\n\nimport torch\nimport torch.utils.checkpoint\nimport torch.nn.functional as F\n\nfrom ldm.modules.diffusionmodules.positional_encoding import SinusoidalPositionalEmbedding\n\n\nclass UNet2DConditionPatchModel(UNet2DConditionModel):\n    @register_to_config\n    def __init__(\n        self,\n\n        sample_size: Optional[int] = None,\n        in_channels: int = 4,\n        out_channels: int = 4,\n        center_input_sample: bool = False,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,\n        down_block_types: Tuple[str] = (\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"DownBlock2D\",\n        ),\n        mid_block_type: Optional[str] = \"UNetMidBlock2DCrossAttn\",\n        up_block_types: Tuple[str] = (\n            \"UpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\"),\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n        layers_per_block: int = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        act_fn: str = \"silu\",\n        norm_num_groups: Optional[int] = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: int = 1280,\n        attention_head_dim: Union[int, Tuple[int]] = 8,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        class_embed_type: Optional[str] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        time_embedding_type: str = \"positional\",  # fourier, positional\n        timestep_post_act: Optional[str] = None,\n        time_cond_proj_dim: Optional[int] = None,\n        conv_in_kernel: int = 3,\n        conv_out_kernel: int = 3,\n        projection_class_embeddings_input_dim: Optional[int] = None,\n\n        padding_idx: StrictInt = 0,\n        init_size: StrictInt = 128,\n        div_half_dim: StrictBool = False,\n        center_shift: StrictInt = 64,\n        interpolation_mode: StrictStr = \"bilinear\",\n    ):\n        super().__init__(sample_size=sample_size,\n                         in_channels=in_channels,\n                         out_channels=out_channels,\n                         center_input_sample=center_input_sample,\n                         flip_sin_to_cos=flip_sin_to_cos,\n                         freq_shift=freq_shift,\n                         down_block_types=down_block_types,\n                         mid_block_type=mid_block_type,\n                         up_block_types=up_block_types,\n                         only_cross_attention=only_cross_attention,\n                         block_out_channels=block_out_channels,\n                         layers_per_block=layers_per_block,\n                         downsample_padding=downsample_padding,\n                         mid_block_scale_factor=mid_block_scale_factor,\n                         act_fn=act_fn,\n                         norm_num_groups=norm_num_groups,\n                         norm_eps=norm_eps,\n                         
cross_attention_dim=cross_attention_dim,\n                         attention_head_dim=attention_head_dim,\n                         dual_cross_attention=dual_cross_attention,\n                         use_linear_projection=use_linear_projection,\n                         class_embed_type=class_embed_type,\n                         num_class_embeds=num_class_embeds,\n                         upcast_attention=upcast_attention,\n                         resnet_time_scale_shift=resnet_time_scale_shift,\n                         time_embedding_type=time_embedding_type,  # fourier, positional\n                         timestep_post_act=timestep_post_act,\n                         time_cond_proj_dim=time_cond_proj_dim,\n                         conv_in_kernel=conv_in_kernel,\n                         conv_out_kernel=conv_out_kernel,\n                         projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,\n                         )\n        assert block_out_channels[0] % 2 == 0\n        self.head_position_encode = SinusoidalPositionalEmbedding(embedding_dim=block_out_channels[0]//2,\n                                                                  padding_idx=padding_idx,\n                                                                  init_size=init_size,\n                                                                  div_half_dim=div_half_dim,\n                                                                  center_shift=center_shift)\n        self.init_size = init_size\n        self.interpolation_mode = interpolation_mode\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        crop_boxes: Optional[torch.Tensor] = None,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        return_dict: bool = True,\n    ) -> Union[UNet2DConditionOutput, Tuple]:\n        # By default samples have to be at least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).\n        # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        # on the fly if necessary.\n\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):\n            forward_upsample_size = True\n\n        # prepare attention_mask\n        if attention_mask is not None:\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # 0. center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU.
So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor(\n                [timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=self.dtype)\n\n        emb = self.time_embedding(t_emb, timestep_cond)\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\n                    \"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)\n            emb = emb + class_emb\n\n        # 2. pre-process\n        sample = self.conv_in(sample)\n\n        head_grid = self.head_position_encode(torch.ones([sample.shape[0], sample.shape[1], self.init_size, self.init_size], dtype=self.dtype,\n                                                         device=sample.device))\n\n        # resize the positional grid with the configured mode; align_corners is only valid for bilinear\n        align_corners = True if self.interpolation_mode == \"bilinear\" else None\n        if crop_boxes is not None:\n            head_grid = torch.cat([F.interpolate(hg.unsqueeze(0)[:, :, box[0]: box[2], box[1]: box[3]],\n                                                 (sample.shape[2], sample.shape[3]), mode=self.interpolation_mode, align_corners=align_corners)\n                                   for hg, box in\n                                   zip(head_grid, crop_boxes)], dim=0)\n        else:\n            head_grid = F.interpolate(\n                head_grid, (sample.shape[2], sample.shape[3]), mode=self.interpolation_mode, align_corners=align_corners)\n\n        sample += head_grid\n\n        # 3. down\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                )\n            else:\n                sample, res_samples = downsample_block(\n                    hidden_states=sample, temb=emb)\n\n            down_block_res_samples += res_samples\n\n        # 4. mid\n        if self.mid_block is not None:\n            sample = self.mid_block(\n                sample,\n                emb,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=attention_mask,\n                cross_attention_kwargs=cross_attention_kwargs,\n            )\n\n        # 5. 
up\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets):]\n            down_block_res_samples = down_block_res_samples[: -len(\n                upsample_block.resnets)]\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size\n                )\n        # 6. post-process\n        if self.conv_norm_out:\n            sample = self.conv_norm_out(sample)\n            sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        if not return_dict:\n            return (sample,)\n\n        return UNet2DConditionOutput(sample=sample)\n\n\nif __name__ == \"__main__\":\n    unet = UNet2DConditionPatchModel.from_pretrained(\n        \"CompVis/stable-diffusion-v1-4\", subfolder=\"unet\", revision=None, low_cpu_mem_usage=False, device_map=None\n    )"
  },
  {
    "path": "diffusers_sample.py",
    "content": "from diffusers import StableDiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nimport torch\nfrom diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel\nfrom transformers import AutoTokenizer, PretrainedConfig\nfrom typing import Any, Callable, Dict, List, Optional, Union\nimport importlib\nimport os \nimport diffusers\nfrom diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT\nimport argparse\nfrom accelerate.utils import ProjectConfiguration, set_seed\n\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"RobertaSeriesModelWithTransformation\":\n        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation\n\n        return RobertaSeriesModelWithTransformation\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n    \n    \nclass StableDiffusionGuidancePipeline(StableDiffusionPipeline):\n    text_encoder_orig = None\n    unet_orig = None\n    def __init__(\n        self,\n        vae,\n        text_encoder,\n        tokenizer,\n        unet,\n        scheduler,\n        safety_checker,\n        feature_extractor,\n        requires_safety_checker,\n    ):\n        super().__init__(vae,\n            text_encoder,\n            tokenizer,\n            unet,\n            scheduler,\n            safety_checker,\n            feature_extractor,\n            requires_safety_checker,)\n        self.config['unet'] = (unet.__module__, unet.config['_class_name'])\n        \n        \n    def add_pretrained_model(self, text_encoder, unet):\n        self.text_encoder_orig = text_encoder.to(self._execution_device)\n        self.unet_orig = unet.to(self._execution_device)\n        \n    @classmethod\n    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):\n        torch_dtype = kwargs.pop(\"torch_dtype\", None)\n        provider = kwargs.pop(\"provider\", None)\n        sess_options = kwargs.pop(\"sess_options\", None)\n        device_map = kwargs.pop(\"device_map\", None)\n        low_cpu_mem_usage = kwargs.pop(\"low_cpu_mem_usage\", _LOW_CPU_MEM_USAGE_DEFAULT)\n        return_cached_folder = kwargs.pop(\"return_cached_folder\", False)\n\n        # 1. Download the checkpoints and configs\n        # use snapshot download here to get it working from from_pretrained\n        cached_folder = pretrained_model_name_or_path\n\n        config_dict = cls.load_config(cached_folder)\n        \n        if config_dict['unet'][0] is None:\n            config_dict['unet'][0] = 'diffusers_models'\n\n        # 2. 
Load the pipeline class\n        pipeline_class = cls\n\n        # some modules can be passed directly to the init\n        # in this case they are already instantiated in `kwargs`\n        # extract them here\n        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)\n        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}\n        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}\n\n        init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)\n\n        # define init kwargs\n        init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}\n        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}\n\n        # remove `null` components\n        def load_module(name, value):\n            if isinstance(value, bool):\n                return False\n            if value[0] is None:\n                return False\n            if name in passed_class_obj and passed_class_obj[name] is None:\n                return False\n            return True\n\n        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}\n\n        # import it here to avoid circular import\n        from diffusers import pipelines\n\n        # 3. Load each module in the pipeline\n        for name, (library_name, class_name) in init_dict.items():\n            # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names\n            if class_name.startswith(\"Flax\"):\n                class_name = class_name[4:]\n\n            is_pipeline_module = hasattr(pipelines, library_name)\n            loaded_sub_model = None\n\n            # if the model is in a pipeline module, then we load it from the pipeline\n            if name in passed_class_obj:\n                # 1. 
check that passed_class_obj has correct parent class\n                # pass\n\n                # set passed class object\n                loaded_sub_model = passed_class_obj[name]\n            elif is_pipeline_module:\n                pipeline_module = getattr(pipelines, library_name)\n                class_obj = getattr(pipeline_module, class_name)\n            else:\n                # else we just import it from the library.\n                # NOTE: here I reuse library_name as the module name\n                library = importlib.import_module(library_name)\n                class_obj = getattr(library, class_name)\n            \n            if loaded_sub_model is None:\n                load_method_name = 'from_pretrained'\n\n                load_method = getattr(class_obj, load_method_name)\n                loading_kwargs = {}\n\n                if issubclass(class_obj, torch.nn.Module):\n                    loading_kwargs[\"torch_dtype\"] = torch_dtype\n                # if issubclass(class_obj, diffusers.OnnxRuntimeModel):\n                #     loading_kwargs[\"provider\"] = provider\n                #     loading_kwargs[\"sess_options\"] = sess_options\n\n                is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)\n\n                # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.\n                # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.\n                # This makes sure that the weights won't be initialized which significantly speeds up loading.\n                if is_diffusers_model:\n                    loading_kwargs[\"device_map\"] = device_map\n                    loading_kwargs[\"low_cpu_mem_usage\"] = low_cpu_mem_usage\n\n                # check if the module is in a subdirectory\n                if os.path.isdir(os.path.join(cached_folder, name)):\n                    loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)\n                else:\n                    # else load from the root directory\n                    loaded_sub_model = load_method(cached_folder, **loading_kwargs)\n\n            init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)\n\n        init_kwargs['requires_safety_checker'] = False\n        # 4. Potentially add passed objects if expected\n        missing_modules = set(expected_modules) - set(init_kwargs.keys())\n        passed_modules = list(passed_class_obj.keys())\n        optional_modules = pipeline_class._optional_components\n        if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):\n            for module in missing_modules:\n                init_kwargs[module] = passed_class_obj.get(module, None)\n        elif len(missing_modules) > 0:\n            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs\n            raise ValueError(\n                f\"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed.\"\n            )\n\n        # 5. 
Instantiate the pipeline\n        model = pipeline_class(**init_kwargs)\n\n        if return_cached_folder:\n            return model, cached_folder\n        return model\n        \n    def _encode_prompt_orig(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n        \"\"\"
\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                # kept for parity with the upstream pipeline, which logs the truncated text here\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n\n            if hasattr(self.text_encoder_orig.config, \"use_attention_mask\") and self.text_encoder_orig.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            prompt_embeds = self.text_encoder_orig(\n                text_input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_orig.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder_orig.config, \"use_attention_mask\") and self.text_encoder_orig.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder_orig(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_orig.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n    \n    @torch.no_grad()\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_edit: Optional[Union[str, List[str]]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        model_based_guidance_scale: float = 0.0,\n        K_min: int = 400,\n    ):\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds\n        )\n        self.check_inputs(\n            prompt_edit, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds\n        )\n\n        # 2. 
Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n            assert isinstance(prompt_edit, str)\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n            assert isinstance(prompt_edit, list)\n            assert len(prompt_edit) == len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3. Encode input prompt\n        prompt_embeds = self._encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n        )\n        prompt_embeds_edit = self._encode_prompt_orig(\n            prompt_edit,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=None,\n            negative_prompt_embeds=None,\n        )\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 5. Prepare latent variables\n        num_channels_latents = self.unet.in_channels\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7. Denoising loop
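\n        # The pre-trained UNet (self.unet_orig) denoises with the editing prompt; while t > K_min,\n        # the fine-tuned UNet's prediction for the original prompt is blended in with weight\n        # model_based_guidance_scale before classifier-free guidance is applied.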
\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet_orig(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds_edit,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                ).sample\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                if t > K_min and model_based_guidance_scale > 0.0:\n                    # model-based guidance: query the fine-tuned UNet on the original prompt\n                    noise_pred_guidance = self.unet(\n                        latent_model_input,\n                        t,\n                        encoder_hidden_states=prompt_embeds,\n                        cross_attention_kwargs=cross_attention_kwargs,\n                    ).sample\n                    if do_classifier_free_guidance:\n                        noise_pred_guidance_uncond, noise_pred_guidance_text = noise_pred_guidance.chunk(2)\n                        noise_pred_text = noise_pred_text * (1 - model_based_guidance_scale) + noise_pred_guidance_text * model_based_guidance_scale\n                    else:\n                        noise_pred = noise_pred * (1 - model_based_guidance_scale) + noise_pred_guidance * model_based_guidance_scale\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        callback(i, t, latents)\n\n        if output_type == \"latent\":\n            image = latents\n        elif output_type == \"pil\":\n            # 8. Post-processing\n            image = self.decode_latents(latents)\n            # 9. Convert to PIL\n            image = self.numpy_to_pil(image)\n        else:\n            # 8. 
Post-processing\n            image = self.decode_latents(latents)\n\n        # Offload last model to CPU\n        if hasattr(self, \"final_offload_hook\") and self.final_offload_hook is not None:\n            self.final_offload_hook.offload()\n\n        if not return_dict:\n            return (image, None)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)\n    \ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(\n        description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        '--prompt',\n        type=str,\n        default=None,\n        required=True,\n        help='Text prompt for the fine-tuned model.'\n    )\n    parser.add_argument(\n        '--editing_prompt',\n        type=str,\n        default=None,\n        required=True,\n        help='Text prompt for the pre-trained model.'\n    )\n    parser.add_argument(\"--seed\", type=int, default=412441,\n                        help=\"A seed for reproducible sampling.\")\n    parser.add_argument('--num_images_per_prompt', type=int, default=2, help='Batch size.')\n    parser.add_argument('--num_iterations', type=int, default=1,)\n    parser.add_argument('--model_based_guidance_scale', type=float, default=0.3, help='Scale of model-based guidance.')\n    parser.add_argument('--guidance_scale', type=float, default=7.5, help='Scale of classifier-free guidance.')\n    parser.add_argument('--K', default=400, type=int, help='step to stop guidance')\n    parser.add_argument('--ddim_steps', default=100, type=int, help='Number of ddim steps')\n    parser.add_argument('--height', default=512, type=int, help='Height of the image')\n    parser.add_argument('--width', default=512, type=int, help='Width of the image')\n    \n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    return args\n\n    \nif __name__ == \"__main__\":\n    args = parse_args()\n    set_seed(args.seed)\n    \n    pretrained_model_name_or_path = \"CompVis/stable-diffusion-v1-4\"\n\n    text_encoder_cls = import_model_class_from_model_name_or_path(pretrained_model_name_or_path, None)\n    text_encoder = text_encoder_cls.from_pretrained(\n            pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=None\n        )\n\n    unet = UNet2DConditionModel.from_pretrained(\n        pretrained_model_name_or_path, subfolder=\"unet\", revision=None\n    )\n    \n    \n    model_id = args.pretrained_model_name_or_path\n    pipe = StableDiffusionGuidancePipeline.from_pretrained(model_id, torch_dtype=torch.float).to(\"cuda\")\n    pipe.add_pretrained_model(text_encoder=text_encoder, unet=unet)\n\n    prompt = args.prompt\n    prompt_edit = args.editing_prompt\n    file_name = (prompt + '[SEP]' + prompt_edit).replace(' ', '_')\n    for i in range(args.num_iterations):\n        images = pipe(prompt=prompt, prompt_edit=prompt_edit,\n                    model_based_guidance_scale=args.model_based_guidance_scale, \n                    K_min=args.K, num_inference_steps=args.ddim_steps, \n                    guidance_scale=args.guidance_scale, \n                    num_images_per_prompt=args.num_images_per_prompt, height=args.height, width=args.width).images\n\n        for j, image in 
enumerate(images):\n\n            image.save(\"{}/{}_{}.png\".format(args.pretrained_model_name_or_path, file_name, i * args.num_images_per_prompt + j))"
  },
  {
    "path": "diffusers_train.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2022 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n\nimport argparse\nimport hashlib\nimport itertools\nimport logging\nimport math\nimport os\nimport warnings\nfrom pathlib import Path\nfrom typing import Optional\nimport random\nfrom copy import deepcopy\n\nimport accelerate\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import HfFolder, Repository, create_repo, whoami\nfrom packaging import version\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\nimport numpy as np\n\nimport diffusers\nfrom diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version\nfrom diffusers.utils.import_utils import is_xformers_available\n\nfrom diffusers_models import UNet2DConditionPatchModel\n\n\nlogger = get_logger(__name__)\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"RobertaSeriesModelWithTransformation\":\n        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation\n\n        return RobertaSeriesModelWithTransformation\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(\n        description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=(\n            \"Revision of pretrained model identifier from huggingface.co/models. 
Trainable model components should be\"\n            \" float32 precision.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    parser.add_argument(\n        \"--img_path\",\n        type=str,\n        default=None,\n        required=True,\n    )\n    parser.add_argument(\n        \"--instance_prompt\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"The prompt with identifier specifying the instance\",\n    )\n\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"text-inversion-model\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=412441,\n                        help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_text_encoder\",\n        action=\"store_true\",\n        help=\"Whether to train the text encoder. If set, the text encoder should be float32 precision.\",\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. \"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference.\"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components.\"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step\"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\n            \"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`.\"\n            \" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state\"\n            \" for more details\"\n        ),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of update steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-6,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0,\n                        help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"
\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9,\n                        help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999,\n                        help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float,\n                        default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08,\n                        help=\"Epsilon value for the Adam optimizer.\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0,\n                        type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\",\n                        help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None,\n                        help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--prior_generation_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp32\", \"fp16\", \"bf16\"],\n        help=(\n            \"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32.\"
\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1,\n                        help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by setting grads to None instead of zero. Be aware that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    \n    parser.add_argument(\n        \"--patch_based_training\",\n        action=\"store_true\",\n        help=(\n            \"Activate patch-based training.\"\n        ),\n    )\n    parser.add_argument(\n        '--high_res',\n        type=int,\n        default=1024,\n        help='The highest resolution provided to the dataloader.'\n    )\n    parser.add_argument(\n        '--latent_scale',\n        type=int,\n        default=8,\n        help='The downsampling factor of the latent space.'\n    )\n    parser.add_argument(\n        '--min_crop_frac',\n        type=float,\n        default=0.1,\n        help='The minimum fraction of the image to crop.'\n    )\n    parser.add_argument(\n        '--max_crop_frac',\n        type=float,\n        default=1,\n        help='The maximum fraction of the image to crop.'\n    )\n    parser.add_argument(\n        '--rec_prob',\n        type=float,\n        default=0.1,\n        help='The probability of using the whole image as the crop.'\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\nclass SINEDatasetPatch(Dataset):\n    def __init__(\n        self,\n        img_path,\n        instance_prompt,\n        tokenizer,\n        size=512,\n        high_res=1024,\n        min_crop_frac=0.1,\n        max_crop_frac=1,\n        rec_prob=0.1,\n        latent_scale=8,\n    ):\n        super().__init__()\n        self.size = size\n        self.tokenizer = tokenizer\n        self.img_path = Path(img_path)\n        if not self.img_path.exists():\n            raise ValueError(f\"Image {self.img_path} doesn't exist.\")\n        self.num_instance_images = 1\n        self.instance_prompt = instance_prompt\n        self._length = self.num_instance_images\n\n        self.image = Image.open(img_path)\n        if not self.image.mode == \"RGB\":\n            self.image = self.image.convert(\"RGB\")\n            \n        self.image = self.image.resize((high_res, high_res), resample=Image.Resampling.BICUBIC)\n        \n        self.instance_prompt_ids = self.tokenizer(\n            self.instance_prompt,\n            truncation=True,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids\n\n        self.high_res = high_res\n        self.min_crop_frac = min_crop_frac\n        self.max_crop_frac = max_crop_frac\n        self.rec_prob = rec_prob\n        self.latent_scale = latent_scale\n        \n        self.image_transform = transforms.Compose(\n            [\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n        \n    def __len__(self):\n        # the single instance image is oversampled so that one epoch yields 20 crops\n        return self._length * 20
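\n\n    # _random_crop draws the crop box in latent-grid units (high_res // latent_scale) and scales it\n    # back to pixels, so every crop aligns with the VAE downsampling grid; the returned\n    # (y0, x0, y1, x1) coordinates index the latent-resolution positional-encoding grid directly.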
\n    def _random_crop(self, pil_image):\n        patch_size_y = int(\n            (self.high_res//self.latent_scale) * (random.random() * (self.max_crop_frac - self.min_crop_frac) + self.min_crop_frac))\n        patch_size_x = int(\n            (self.high_res//self.latent_scale) * (random.random() * (self.max_crop_frac - self.min_crop_frac) + self.min_crop_frac))\n        crop_y = random.randrange((self.high_res//self.latent_scale) - patch_size_y + 1)\n        crop_x = random.randrange((self.high_res//self.latent_scale) - patch_size_x + 1)\n        return pil_image.crop((crop_x * self.latent_scale, crop_y * self.latent_scale, (crop_x + patch_size_x) * self.latent_scale, (crop_y + patch_size_y) * self.latent_scale)).resize(\n            (self.size, self.size),\n            resample=Image.Resampling.BICUBIC), crop_y, crop_x, crop_y + patch_size_y, crop_x + patch_size_x\n        \n    def __getitem__(self, i):\n        image = deepcopy(self.image)\n\n        if random.random() > self.rec_prob:\n            image, crop_y, crop_x, crop_y1, crop_x1 = self._random_crop(image)\n            crop_area = torch.tensor([crop_y, crop_x, crop_y1, crop_x1])\n        else:\n            image = image.resize((self.size, self.size), resample=Image.Resampling.BICUBIC)\n            crop_area = torch.tensor([0, 0, self.high_res//self.latent_scale, self.high_res//self.latent_scale])\n\n        example = {\n            \"instance_images\": self.image_transform(image),\n            \"instance_prompt_ids\": self.instance_prompt_ids,\n            \"crop_area\": crop_area,\n        }\n\n        return example\n\n\nclass SINEDatasetSingleRes(Dataset):\n    def __init__(\n        self,\n        img_path,\n        instance_prompt,\n        tokenizer,\n        size=512,\n    ):\n        super().__init__()\n        self.size = size\n        self.tokenizer = tokenizer\n        self.img_path = Path(img_path)\n        if not self.img_path.exists():\n            raise ValueError(f\"Image {self.img_path} doesn't exist.\")\n        self.num_instance_images = 1\n        self.instance_prompt = instance_prompt\n        self._length = self.num_instance_images\n\n        self.image = Image.open(img_path)\n        if not self.image.mode == \"RGB\":\n            self.image = self.image.convert(\"RGB\")\n\n        self.image_transform = transforms.Compose(\n            [\n                transforms.Resize(\n                    (size, size), interpolation=transforms.InterpolationMode.BICUBIC),\n                transforms.ToTensor(),\n                transforms.Normalize([0.5], [0.5]),\n            ]\n        )\n\n        self.image = self.image_transform(self.image)\n        self.instance_prompt_ids = self.tokenizer(\n            self.instance_prompt,\n            truncation=True,\n            padding=\"max_length\",\n            max_length=self.tokenizer.model_max_length,\n            return_tensors=\"pt\",\n        ).input_ids\n\n    def __len__(self):\n        return self._length * 20\n\n    def __getitem__(self, index):\n        example = {\n            \"instance_images\": self.image,\n            \"instance_prompt_ids\": self.instance_prompt_ids,\n        }\n\n        return example\n\n\ndef collate_fn(examples):
\n    input_ids = [example[\"instance_prompt_ids\"] for example in examples]\n    pixel_values = [example[\"instance_images\"] for example in examples]\n\n    # Stack images and prompt ids into batch tensors; crop boxes are passed through when present.\n    pixel_values = torch.stack(pixel_values)\n    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()\n\n    input_ids = torch.cat(input_ids, dim=0)\n    \n    if \"crop_area\" in examples[0]:\n        crop_area = [example[\"crop_area\"] for example in examples]\n        crop_area = torch.stack(crop_area)\n        # crop_area = crop_area.to(memory_format=torch.contiguous_format).float()\n        batch = {\n            \"input_ids\": input_ids,\n            \"pixel_values\": pixel_values,\n            \"crop_area\": crop_area,\n        }\n        return batch\n\n    batch = {\n        \"input_ids\": input_ids,\n        \"pixel_values\": pixel_values,\n    }\n    return batch\n\n\ndef get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):\n    if token is None:\n        token = HfFolder.get_token()\n    if organization is None:\n        username = whoami(token)[\"name\"]\n        return f\"{username}/{model_id}\"\n    else:\n        return f\"{organization}/{model_id}\"\n\n\ndef main(args):\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(\n        total_limit=args.checkpoints_total_limit)\n\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        logging_dir=logging_dir,\n        project_config=accelerator_project_config,\n    )\n\n    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate\n    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.\n    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.\n    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:\n        raise ValueError(\n            \"Gradient accumulation is not supported when training the text encoder in distributed training. \"\n            \"Please set gradient_accumulation_steps to 1. 
This feature will be supported in the future.\"\n        )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.push_to_hub:\n            if args.hub_model_id is None:\n                repo_name = get_full_repo_name(\n                    Path(args.output_dir).name, token=args.hub_token)\n            else:\n                repo_name = args.hub_model_id\n            create_repo(repo_name, exist_ok=True, token=args.hub_token)\n            repo = Repository(\n                args.output_dir, clone_from=repo_name, token=args.hub_token)\n\n            with open(os.path.join(args.output_dir, \".gitignore\"), \"w+\") as gitignore:\n                if \"step_*\" not in gitignore:\n                    gitignore.write(\"step_*\\n\")\n                if \"epoch_*\" not in gitignore:\n                    gitignore.write(\"epoch_*\\n\")\n        elif args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n    # Load the tokenizer\n    if args.tokenizer_name:\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.tokenizer_name, revision=args.revision, use_fast=False)\n    elif args.pretrained_model_name_or_path:\n        tokenizer = AutoTokenizer.from_pretrained(\n            args.pretrained_model_name_or_path,\n            subfolder=\"tokenizer\",\n            revision=args.revision,\n            use_fast=False,\n        )\n\n    # import correct text encoder class\n    text_encoder_cls = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision)\n\n    # Load scheduler and models\n    noise_scheduler = DDPMScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\")\n    text_encoder = text_encoder_cls.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n    )\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision)\n    \n    if args.patch_based_training:\n        unet = UNet2DConditionPatchModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision, low_cpu_mem_usage=False, device_map=None\n        )\n    else:\n        unet = UNet2DConditionModel.from_pretrained(\n            args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n        )\n\n    # `accelerate` 0.16.0 will have better support for customized saving\n    if version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n        def save_model_hook(models, weights, output_dir):\n            for model in models:\n                sub_dir = \"unet\" if type(model) == type(\n           
         unet) else \"text_encoder\"\n                model.save_pretrained(os.path.join(output_dir, sub_dir))\n\n                # make sure to pop weight so that corresponding model is not saved again\n                weights.pop()\n\n        def load_model_hook(models, input_dir):\n            while len(models) > 0:\n                # pop models so that they are not loaded again\n                model = models.pop()\n\n                if type(model) == type(text_encoder):\n                    # load transformers style into model\n                    load_model = text_encoder_cls.from_pretrained(\n                        input_dir, subfolder=\"text_encoder\")\n                    model.config = load_model.config\n                else:\n                    # load diffusers style into model\n                    load_model = UNet2DConditionModel.from_pretrained(\n                        input_dir, subfolder=\"unet\")\n                    model.register_to_config(**load_model.config)\n\n                model.load_state_dict(load_model.state_dict())\n                del load_model\n\n        accelerator.register_save_state_pre_hook(save_model_hook)\n        accelerator.register_load_state_pre_hook(load_model_hook)\n\n    vae.requires_grad_(False)\n    if not args.train_text_encoder:\n        text_encoder.requires_grad_(False)\n\n    if args.enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            import xformers\n\n            xformers_version = version.parse(xformers.__version__)\n            if xformers_version == version.parse(\"0.0.16\"):\n                logger.warn(\n                    \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n                )\n            unet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\n                \"xformers is not available. Make sure it is installed correctly.\")\n\n    if args.gradient_checkpointing:\n        unet.enable_gradient_checkpointing()\n        if args.train_text_encoder:\n            text_encoder.gradient_checkpointing_enable()\n\n    # Check that all trainable models are in full precision\n    low_precision_error_string = (\n        \"Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n        \" doing mixed precision training. A copy of the weights should still be float32.\"\n    )\n\n    if accelerator.unwrap_model(unet).dtype != torch.float32:\n        raise ValueError(\n            f\"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. 
{low_precision_error_string}\"\n        )\n\n    if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:\n        raise ValueError(\n            f\"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}.\"\n            f\" {low_precision_error_string}\"\n        )\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32:\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps *\n            args.train_batch_size * accelerator.num_processes\n        )\n\n    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\n    if args.use_8bit_adam:\n        try:\n            import bitsandbytes as bnb\n        except ImportError:\n            raise ImportError(\n                \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n            )\n\n        optimizer_class = bnb.optim.AdamW8bit\n    else:\n        optimizer_class = torch.optim.AdamW\n\n    # Optimizer creation\n    params_to_optimize = (\n        itertools.chain(unet.parameters(), text_encoder.parameters(\n        )) if args.train_text_encoder else unet.parameters()\n    )\n    optimizer = optimizer_class(\n        params_to_optimize,\n        lr=args.learning_rate,\n        betas=(args.adam_beta1, args.adam_beta2),\n        weight_decay=args.adam_weight_decay,\n        eps=args.adam_epsilon,\n    )\n\n    # Dataset and DataLoaders creation:\n    train_dataset = SINEDatasetSingleRes(\n        img_path=args.img_path,\n        instance_prompt=args.instance_prompt,\n        tokenizer=tokenizer,\n        size=args.resolution,\n    ) if not args.patch_based_training else SINEDatasetPatch(\n        img_path=args.img_path,\n        instance_prompt=args.instance_prompt,\n        tokenizer=tokenizer,\n        size=args.resolution,\n        high_res=args.high_res,\n        min_crop_frac=args.min_crop_frac,\n        max_crop_frac=args.max_crop_frac,\n        rec_prob=args.rec_prob,\n        latent_scale=args.latent_scale,\n    )\n\n    train_dataloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=args.train_batch_size,\n        shuffle=True,\n        collate_fn=lambda examples: collate_fn(examples),\n        num_workers=args.dataloader_num_workers,\n    )\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(\n        len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,\n        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    if args.train_text_encoder:\n        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, text_encoder, optimizer, train_dataloader, lr_scheduler\n        )\n    else:\n        unet, optimizer, 
train_dataloader, lr_scheduler = accelerator.prepare(\n            unet, optimizer, train_dataloader, lr_scheduler\n        )\n\n    # For mixed precision training we cast the text_encoder and vae weights to half-precision\n    # as these models are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n\n    # Move vae and text_encoder to device and cast to weight_dtype\n    vae.to(accelerator.device, dtype=weight_dtype)\n    if not args.train_text_encoder:\n        text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(\n        len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(\n        args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"dreambooth\", config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * \\\n        accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(\n        f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(\n        f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(\n        f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = os.path.basename(args.resume_from_checkpoint)\n        else:\n            # Get the mos recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(os.path.join(args.output_dir, path))\n            global_step = int(path.split(\"-\")[1])\n\n            resume_global_step = global_step * args.gradient_accumulation_steps\n            first_epoch = global_step // num_update_steps_per_epoch\n            resume_step = resume_global_step % (\n                num_update_steps_per_epoch * args.gradient_accumulation_steps)\n\n    # Only show the progress bar once on each machine.\n    progress_bar = tqdm(range(global_step, args.max_train_steps),\n                        disable=not accelerator.is_local_main_process)\n    progress_bar.set_description(\"Steps\")\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        unet.train()\n        if args.train_text_encoder:\n            text_encoder.train()\n        for step, batch in enumerate(train_dataloader):\n            # Skip steps until we reach the resumed step\n            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n                if step % args.gradient_accumulation_steps == 0:\n                    progress_bar.update(1)\n                continue\n\n            with accelerator.accumulate(unet):\n                # Convert images to latent space\n                latents = vae.encode(batch[\"pixel_values\"].to(\n                    dtype=weight_dtype)).latent_dist.sample()\n                latents = latents * vae.config.scaling_factor\n\n                # Sample noise that we'll add to the latents\n                noise = torch.randn_like(latents)\n                bsz = latents.shape[0]\n                # Sample a random timestep for each image\n                timesteps = torch.randint(\n                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                noisy_latents = noise_scheduler.add_noise(\n                    latents, noise, timesteps)\n\n                # Get the text embedding for conditioning\n                encoder_hidden_states = text_encoder(batch[\"input_ids\"])[0]\n\n                # Predict the noise residual\n                model_pred = unet(noisy_latents, timesteps,encoder_hidden_states).sample if not args.patch_based_training else \\\n                    unet(noisy_latents, timesteps,encoder_hidden_states, crop_boxes=batch['crop_area']).sample\n\n                # Get the target for loss depending on the prediction type\n                if noise_scheduler.config.prediction_type == \"epsilon\":\n                    target = noise\n                elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                    target = noise_scheduler.get_velocity(\n                        latents, noise, timesteps)\n                else:\n                    raise ValueError(\n                        f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n                loss = F.mse_loss(model_pred.float(),\n                                  target.float(), reduction=\"mean\")\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        
itertools.chain(unet.parameters(),\n                                        text_encoder.parameters())\n                        if args.train_text_encoder\n                        else unet.parameters()\n                    )\n                    accelerator.clip_grad_norm_(\n                        params_to_clip, args.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n\n                if global_step % args.checkpointing_steps == 0:\n                    if accelerator.is_main_process:\n                        save_path = os.path.join(\n                            args.output_dir, f\"checkpoint-{global_step}\")\n                        accelerator.save_state(save_path)\n                        logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss\": loss.detach().item(\n            ), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n    # Create the pipeline using using the trained modules and save it.\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        pipeline = DiffusionPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            unet=accelerator.unwrap_model(unet),\n            text_encoder=accelerator.unwrap_model(text_encoder),\n            revision=args.revision,\n        )\n        pipeline.save_pretrained(args.output_dir)\n\n        if args.push_to_hub:\n            repo.push_to_hub(commit_message=\"End of training\",\n                             blocking=False, auto_lfs_prune=True)\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n"
  },
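  {
    "path": "examples/prediction_target_sketch.py",
    "content": "# NOTE: hypothetical example sketch; this file is not part of the original SINE sources.\n# It illustrates how the training loop in the script above builds its regression target:\n# depending on the scheduler's prediction_type, the UNet regresses either the added noise\n# ('epsilon') or the velocity ('v_prediction'). All shapes and values are illustrative.\nimport torch\nimport torch.nn.functional as F\nfrom diffusers import DDPMScheduler\n\nnoise_scheduler = DDPMScheduler(num_train_timesteps=1000, prediction_type='epsilon')\n\nlatents = torch.randn(2, 4, 64, 64)  # stand-in for VAE-encoded image latents\nnoise = torch.randn_like(latents)\ntimesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (2,)).long()\n\n# forward diffusion: produce x_t from x_0 according to the noise schedule\nnoisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\nif noise_scheduler.config.prediction_type == 'epsilon':\n    target = noise\nelif noise_scheduler.config.prediction_type == 'v_prediction':\n    target = noise_scheduler.get_velocity(latents, noise, timesteps)\n\nmodel_pred = torch.randn_like(noisy_latents)  # placeholder for the UNet output\nloss = F.mse_loss(model_pred.float(), target.float(), reduction='mean')\nprint(loss.item())\n"
  },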
  {
    "path": "environment.yml",
    "content": "name: SINE\nchannels:\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - _openmp_mutex=5.1=1_gnu\n  - ca-certificates=2022.10.11=h06a4308_0\n  - certifi=2022.9.24=py38h06a4308_0\n  - libedit=3.1.20210910=h7f8727e_0\n  - libffi=3.2.1=hf484d3e_1007\n  - libgcc-ng=11.2.0=h1234567_1\n  - libgomp=11.2.0=h1234567_1\n  - libstdcxx-ng=11.2.0=h1234567_1\n  - ncurses=6.3=h5eee18b_3\n  - openssl=1.1.1s=h7f8727e_0\n  - pip=22.2.2=py38h06a4308_0\n  - python=3.8.0=h0371630_2\n  - readline=7.0=h7b6447c_5\n  - sqlite=3.33.0=h62c20be_0\n  - tk=8.6.12=h1ccaba5_0\n  - wheel=0.37.1=pyhd3eb1b0_0\n  - xz=5.2.8=h5eee18b_0\n  - zlib=1.2.13=h5eee18b_0\n  - pip:\n    - absl-py==1.3.0\n    - accelerate==0.16.0\n    - aiohttp==3.8.3\n    - aiosignal==1.3.1\n    - albumentations==1.1.0\n    - antlr4-python3-runtime==4.8\n    - asttokens==2.2.1\n    - async-timeout==4.0.2\n    - attrs==22.1.0\n    - autopep8==2.0.0\n    - backcall==0.2.0\n    - cachetools==5.2.0\n    - charset-normalizer==2.1.1\n    - debugpy==1.6.4\n    - decorator==5.1.1\n    - einops==0.6.0\n    - entrypoints==0.4\n    - executing==1.2.0\n    - filelock==3.8.2\n    - frozenlist==1.3.3\n    - fsspec==2022.11.0\n    - ftfy==6.1.1\n    - future==0.18.2\n    - google-auth==2.15.0\n    - google-auth-oauthlib==0.4.6\n    - grpcio==1.51.1\n    - huggingface-hub==0.11.1\n    - idna==3.4\n    - imageio==2.14.1\n    - imageio-ffmpeg==0.4.7\n    - importlib-metadata==5.1.0\n    - ipdb==0.13.9\n    - ipykernel==6.17.1\n    - ipython==8.7.0\n    - jedi==0.18.2\n    - jinja2==3.1.2\n    - joblib==1.2.0\n    - jupyter-client==7.4.8\n    - jupyter-core==5.1.0\n    - kornia==0.6.0\n    - markdown==3.4.1\n    - markupsafe==2.1.1\n    - matplotlib-inline==0.1.6\n    - multidict==6.0.3\n    - nest-asyncio==1.5.6\n    - networkx==2.8.8\n    - numpy==1.23.5\n    - oauthlib==3.2.2\n    - omegaconf==2.1.1\n    - opencv-python==4.2.0.34\n    - opencv-python-headless==4.6.0.66\n    - packaging==21.3\n    - pandas==1.5.2\n    - parso==0.8.3\n    - pexpect==4.8.0\n    - pickleshare==0.7.5\n    - pillow==9.3.0\n    - platformdirs==2.6.0\n    - prompt-toolkit==3.0.36\n    - protobuf==3.20.3\n    - psutil==5.9.4\n    - ptyprocess==0.7.0\n    - pudb==2019.2\n    - pure-eval==0.2.2\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - pycodestyle==2.10.0\n    - pydantic==1.10.2\n    - pydeprecate==0.3.1\n    - pygments==2.13.0\n    - pyparsing==3.0.9\n    - python-dateutil==2.8.2\n    - pytorch-lightning==1.5.9\n    - pytz==2022.6\n    - pywavelets==1.4.1\n    - pyyaml==6.0\n    - pyzmq==24.0.1\n    - qudida==0.0.4\n    - regex==2022.10.31\n    - requests==2.28.1\n    - requests-oauthlib==1.3.1\n    - rsa==4.9\n    - scikit-image==0.19.3\n    - scikit-learn==1.1.3\n    - scipy==1.9.3\n    - setuptools==59.5.0\n    - six==1.16.0\n    - stack-data==0.6.2\n    - tensorboard==2.11.0\n    - tensorboard-data-server==0.6.1\n    - tensorboard-plugin-wit==1.8.1\n    - test-tube==0.7.5\n    - threadpoolctl==3.1.0\n    - tifffile==2022.10.10\n    - tokenizers==0.13.2\n    - toml==0.10.2\n    - tomli==2.0.1\n    - torch==1.11.0+cu113\n    - torchmetrics==0.11.0\n    - torchvision==0.12.0+cu113\n    - tornado==6.2\n    - tqdm==4.64.1\n    - traitlets==5.6.0\n    - transformers==4.25.1\n    - typing-extensions==4.5.0\n    - urllib3==1.26.13\n    - urwid==2.1.2\n    - wcwidth==0.2.5\n    - werkzeug==2.2.2\n    - yarl==1.8.2\n    - zipp==3.11.0\n"
  },
  {
    "path": "ldm/data/__init__.py",
    "content": ""
  },
  {
    "path": "ldm/data/base.py",
    "content": "from abc import abstractmethod\nfrom torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset\n\n\nclass Txt2ImgIterableBaseDataset(IterableDataset):\n    '''\n    Define an interface to make the IterableDatasets for text2img data chainable\n    '''\n    def __init__(self, num_records=0, valid_ids=None, size=256):\n        super().__init__()\n        self.num_records = num_records\n        self.valid_ids = valid_ids\n        self.sample_ids = valid_ids\n        self.size = size\n\n        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')\n\n    def __len__(self):\n        return self.num_records\n\n    @abstractmethod\n    def __iter__(self):\n        pass"
  },
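  {
    "path": "examples/iterable_dataset_sketch.py",
    "content": "# NOTE: hypothetical example sketch; this file is not part of the original SINE sources.\n# It shows the minimal contract a subclass of Txt2ImgIterableBaseDataset (defined in\n# ldm/data/base.py) has to fulfil: implement __iter__ and yield example dicts. The\n# class name and the random data below are illustrative assumptions.\nimport numpy as np\n\nfrom ldm.data.base import Txt2ImgIterableBaseDataset\n\n\nclass RandomTxt2ImgDataset(Txt2ImgIterableBaseDataset):\n    def __iter__(self):\n        # yield dicts shaped like the map-style datasets in this package:\n        # an HWC float image in [-1, 1] plus a caption string\n        for idx in self.sample_ids:\n            image = np.random.uniform(-1.0, 1.0, (self.size, self.size, 3)).astype(np.float32)\n            yield {'image': image, 'caption': f'a photo of sample {idx}'}\n\n\nif __name__ == '__main__':\n    ds = RandomTxt2ImgDataset(num_records=4, valid_ids=list(range(4)), size=64)\n    for example in ds:\n        print(example['caption'], example['image'].shape)\n"
  },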
  {
    "path": "ldm/data/personalized.py",
    "content": "import os\nimport numpy as np\nimport re\nimport PIL\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom copy import deepcopy\n\nimport random\nimport torch\n\ntraining_templates_smallest = [\n    'photo of a sks {}',\n]\n\nreg_templates_smallest = [\n    'photo of a {}',\n]\n\nimagenet_templates_small = [\n    'a photo of a {}',\n    'a rendering of a {}',\n    'a cropped photo of the {}',\n    'the photo of a {}',\n    'a photo of a clean {}',\n    'a photo of a dirty {}',\n    'a dark photo of the {}',\n    'a photo of my {}',\n    'a photo of the cool {}',\n    'a close-up photo of a {}',\n    'a bright photo of the {}',\n    'a cropped photo of a {}',\n    'a photo of the {}',\n    'a good photo of the {}',\n    'a photo of one {}',\n    'a close-up photo of the {}',\n    'a rendition of the {}',\n    'a photo of the clean {}',\n    'a rendition of a {}',\n    'a photo of a nice {}',\n    'a good photo of a {}',\n    'a photo of the nice {}',\n    'a photo of the small {}',\n    'a photo of the weird {}',\n    'a photo of the large {}',\n    'a photo of a cool {}',\n    'a photo of a small {}',\n    'an illustration of a {}',\n    'a rendering of a {}',\n    'a cropped photo of the {}',\n    'the photo of a {}',\n    'an illustration of a clean {}',\n    'an illustration of a dirty {}',\n    'a dark photo of the {}',\n    'an illustration of my {}',\n    'an illustration of the cool {}',\n    'a close-up photo of a {}',\n    'a bright photo of the {}',\n    'a cropped photo of a {}',\n    'an illustration of the {}',\n    'a good photo of the {}',\n    'an illustration of one {}',\n    'a close-up photo of the {}',\n    'a rendition of the {}',\n    'an illustration of the clean {}',\n    'a rendition of a {}',\n    'an illustration of a nice {}',\n    'a good photo of a {}',\n    'an illustration of the nice {}',\n    'an illustration of the small {}',\n    'an illustration of the weird {}',\n    'an illustration of the large {}',\n    'an illustration of a cool {}',\n    'an illustration of a small {}',\n    'a depiction of a {}',\n    'a rendering of a {}',\n    'a cropped photo of the {}',\n    'the photo of a {}',\n    'a depiction of a clean {}',\n    'a depiction of a dirty {}',\n    'a dark photo of the {}',\n    'a depiction of my {}',\n    'a depiction of the cool {}',\n    'a close-up photo of a {}',\n    'a bright photo of the {}',\n    'a cropped photo of a {}',\n    'a depiction of the {}',\n    'a good photo of the {}',\n    'a depiction of one {}',\n    'a close-up photo of the {}',\n    'a rendition of the {}',\n    'a depiction of the clean {}',\n    'a rendition of a {}',\n    'a depiction of a nice {}',\n    'a good photo of a {}',\n    'a depiction of the nice {}',\n    'a depiction of the small {}',\n    'a depiction of the weird {}',\n    'a depiction of the large {}',\n    'a depiction of a cool {}',\n    'a depiction of a small {}',\n]\n\nimagenet_dual_templates_small = [\n    'a photo of a {} with {}',\n    'a rendering of a {} with {}',\n    'a cropped photo of the {} with {}',\n    'the photo of a {} with {}',\n    'a photo of a clean {} with {}',\n    'a photo of a dirty {} with {}',\n    'a dark photo of the {} with {}',\n    'a photo of my {} with {}',\n    'a photo of the cool {} with {}',\n    'a close-up photo of a {} with {}',\n    'a bright photo of the {} with {}',\n    'a cropped photo of a {} with {}',\n    'a photo of the {} with {}',\n    'a good photo of the {} with {}',\n    
'a photo of one {} with {}',\n    'a close-up photo of the {} with {}',\n    'a rendition of the {} with {}',\n    'a photo of the clean {} with {}',\n    'a rendition of a {} with {}',\n    'a photo of a nice {} with {}',\n    'a good photo of a {} with {}',\n    'a photo of the nice {} with {}',\n    'a photo of the small {} with {}',\n    'a photo of the weird {} with {}',\n    'a photo of the large {} with {}',\n    'a photo of a cool {} with {}',\n    'a photo of a small {} with {}',\n]\n\nreg_templates_no_class_smallest = [\n    'a photo',\n]\n\nreg_templates_no_class_small = [\n    'a photo',\n    'a rendering',\n    'a cropped photo',\n    'the photo',\n    'a dark photo',\n    'a close-up photo',\n    'a bright photo',\n    'a cropped photo',\n    'a good photo',\n    'a rendition',\n    'an illustration',\n    'a depiction',\n]\n\n\nper_img_token_list = [\n    'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',\n]\n\n
class PersonalizedBase(Dataset):\n    def __init__(self,\n                 data_root,\n                 size=None,\n                 repeats=100,\n                 interpolation=\"bicubic\",\n                 flip_p=0.5,\n                 set=\"train\",\n                 placeholder_token=\"dog\",\n                 per_image_tokens=False,\n                 center_crop=False,\n                 mixing_prob=0.25,\n                 coarse_class_text=None,\n                 reg=False\n                 ):\n\n        self.data_root = data_root\n\n        self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]\n\n        self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        self.placeholder_token = placeholder_token\n\n        self.per_image_tokens = per_image_tokens\n        self.center_crop = center_crop\n        self.mixing_prob = mixing_prob\n\n        self.coarse_class_text = coarse_class_text\n\n        if per_image_tokens:\n            assert self.num_images < len(per_img_token_list), f\"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} images. To enable larger sets, add more tokens to 'per_img_token_list'.\"\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.size = size\n        self.interpolation = {\"linear\": PIL.Image.LINEAR,\n                              \"bilinear\": PIL.Image.BILINEAR,\n                              \"bicubic\": PIL.Image.BICUBIC,\n                              \"lanczos\": PIL.Image.LANCZOS,\n                              }[interpolation]\n        self.flip = transforms.RandomHorizontalFlip(p=flip_p)\n        self.reg = reg\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = {}\n        image = Image.open(self.image_paths[i % self.num_images])\n\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n\n        placeholder_string = self.placeholder_token\n        if self.coarse_class_text:\n            placeholder_string = f\"{self.coarse_class_text} {placeholder_string}\"\n\n        if not self.reg:\n            text = random.choice(training_templates_smallest).format(placeholder_string)\n        else:\n            text = random.choice(reg_templates_smallest).format(placeholder_string)\n\n        example[\"caption\"] = text\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n\n        if self.center_crop:\n            crop = min(img.shape[0], img.shape[1])\n            h, w = img.shape[0], img.shape[1]\n            img = img[(h - crop) // 2:(h + crop) // 2,\n                (w - crop) // 2:(w + crop) // 2]\n\n        image = Image.fromarray(img)\n        if self.size is not None:\n            image = image.resize((self.size, self.size), resample=self.interpolation)\n\n        image = self.flip(image)\n        image = np.array(image).astype(np.uint8)\n        example[\"image\"] = (image / 127.5 - 1.0).astype(np.float32)\n        return example\n\n
def crop_image(img, size=512, cropping='random', crop_scale=[1, 1]):\n    if cropping in ['random', 'random_long_edge']:\n        h, w = img.shape[0], img.shape[1]\n        crop = min(h, w)\n        crop = int(torch.empty(1).uniform_(crop_scale[0], crop_scale[1]).item() * crop)\n        offset_h = np.random.randint(0, h - crop + 1)\n        offset_w = np.random.randint(0, w - crop + 1)\n        img = img[offset_h:offset_h + crop, offset_w:offset_w + crop]\n    elif cropping in ['center', 'center_long_edge']:\n        h, w = img.shape[0], img.shape[1]\n        crop = min(h, w)\n        crop = int(torch.empty(1).uniform_(crop_scale[0], crop_scale[1]).item() * crop)\n        img = img[(h - crop) // 2:(h + crop) // 2, (w - crop) // 2:(w + crop) // 2]\n    elif cropping == 'crop':\n        h, w = img.shape[0], img.shape[1]\n        crop = min(h, w, size)\n        crop = int(torch.empty(1).uniform_(crop_scale[0], crop_scale[1]).item() * crop)\n        offset_h = np.random.randint(0, h - crop + 1)\n        offset_w = np.random.randint(0, w - crop + 1)\n        img = img[offset_h:offset_h + crop, offset_w:offset_w + crop]\n    else:\n        raise NotImplementedError\n    return img\n\n\n
class PersonalizedMulti(Dataset):\n    def __init__(\n        self,\n        data_root=\"\",\n        size=None,\n        repeats=100,\n        interpolation=\"lanczos\",\n        flip_p=0.5,\n        which_set=\"train\",\n        placeholder_token=\"sks\",\n        per_image_tokens=False,\n        cropping='random_long_edge',\n        crop_scale=[1, 1],\n        mixing_prob=0.25,\n        coarse_class_text=None,\n        reg=False,\n        use_small_template=False,\n        delimiters=\",|:|;\",\n        **kwargs,\n    ):\n        # NOTE: split the delimiter-separated strings into per-concept lists\n        data_root = re.split(delimiters, data_root)\n        placeholder_token = re.split(delimiters, placeholder_token)\n        if coarse_class_text:\n            coarse_class_text = re.split(delimiters, coarse_class_text)\n            coarse_class_text = [(None if (s in ['none', 'None', 'null', 'Null']) else s) for s in coarse_class_text]\n        else:\n            coarse_class_text = [None] * len(data_root)\n        assert len(placeholder_token) == len(data_root) == len(coarse_class_text)\n        self.keys = placeholder_token\n        self.data_root = {k: v for k, v in zip(self.keys, data_root)}\n        self.placeholder_token = {k: v for k, v in zip(self.keys, placeholder_token)}\n        self.coarse_class_text = {k: v for k, v in zip(self.keys, coarse_class_text)}\n\n        self.image_paths = {k: [os.path.join(self.data_root[k], file_path) for file_path in os.listdir(self.data_root[k])] if os.path.isdir(self.data_root[k]) else [self.data_root[k]] for k in self.keys}\n\n        self.num_images = {k: len(self.image_paths[k]) for k in self.keys}\n        self._length = max([self.num_images[k] for k in self.keys])\n        if which_set == \"train\":\n            self._length = self._length * repeats\n\n        self.reg = reg\n        self.cropping = cropping\n        self.crop_scale = crop_scale\n        self.mixing_prob = mixing_prob\n\n        self.use_small_template = use_small_template\n        self.templates = {\n            k: self.setup_templates(\n                placeholder_token=self.placeholder_token[k],\n                coarse_class_text=self.coarse_class_text[k],\n                reg=reg, use_small_template=use_small_template,\n            ) for k in self.keys\n        }\n\n        self.size = size\n        self.interpolation = {\n            \"linear\": PIL.Image.LINEAR,\n            \"bilinear\": PIL.Image.BILINEAR,\n            \"bicubic\": PIL.Image.BICUBIC,\n            \"lanczos\": PIL.Image.LANCZOS,\n        }[interpolation]\n        self.flip = transforms.RandomHorizontalFlip(p=flip_p)\n\n    def setup_templates(self, placeholder_token='sks', coarse_class_text='dog', reg=False, use_small_template=False):\n        if reg:  # NOTE: reg dataset\n            if coarse_class_text:\n                placeholder_string = f\"{coarse_class_text}\"\n                templates = imagenet_templates_small if use_small_template else reg_templates_smallest\n                templates = [t.format(placeholder_string) for t in templates]\n            else:\n                templates = reg_templates_no_class_small if use_small_template else reg_templates_no_class_smallest\n        else:  # NOTE: train dataset\n            if coarse_class_text:\n                placeholder_string = f\"{placeholder_token} {coarse_class_text}\"\n            else:\n                placeholder_string = f\"{placeholder_token}\"\n            templates = imagenet_templates_small if use_small_template else reg_templates_smallest\n            templates = [t.format(placeholder_string) for t in templates]\n        return templates\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        key = random.choice(self.keys)\n\n        example = {}\n\n        image = Image.open(self.image_paths[key][i % self.num_images[key]])\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n        img = crop_image(img, size=self.size, cropping=self.cropping, crop_scale=self.crop_scale)\n        image = Image.fromarray(img)\n        if self.size is not None:\n            image = image.resize((self.size, self.size), resample=self.interpolation)\n\n        image = self.flip(image)\n        image = np.array(image).astype(np.uint8)\n        text = random.choice(self.templates[key])\n\n        example[\"caption\"] = text.rstrip()\n        example[\"image\"] = (image / 127.5 - 1.0).astype(np.float32)\n        example[\"label\"] = [self.keys.index(key)]\n        return example\n\n\n
class SinImageDataset(PersonalizedBase):\n    def __init__(\n        self,\n        data_root,\n        size=None,\n        repeats=100,\n        interpolation=\"bicubic\",\n        flip_p=0.5,\n        set=\"train\",\n        placeholder_token=\"dog\",\n        per_image_tokens=False,\n        center_crop=False,\n        mixing_prob=0.25,\n        coarse_class_text=None,\n        reg=False\n    ):\n        self.data_root = data_root\n        assert os.path.isfile(self.data_root), f\"SinImageDataset requires a path to an image file, not a directory. Got {self.data_root}.\"\n\n        self.image_paths = [self.data_root]*100\n\n        self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        self.placeholder_token = placeholder_token\n\n        self.per_image_tokens = per_image_tokens\n        self.center_crop = center_crop\n        self.mixing_prob = mixing_prob\n\n        self.coarse_class_text = coarse_class_text\n\n        if per_image_tokens:\n            assert self.num_images < len(per_img_token_list), f\"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} images. To enable larger sets, add more tokens to 'per_img_token_list'.\"\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.size = size\n        self.interpolation = {\"linear\": PIL.Image.LINEAR,\n                              \"bilinear\": PIL.Image.BILINEAR,\n                              \"bicubic\": PIL.Image.BICUBIC,\n                              \"lanczos\": PIL.Image.LANCZOS,\n                              }[interpolation]\n        self.flip = transforms.RandomHorizontalFlip(p=flip_p)\n        self.reg = reg\n\n\n
class SinImageHighResDataset(Dataset):\n    def __init__(self,\n        data_root,\n        size=512,\n        high_resolution=1024,\n        latent_scale=8,\n        min_crop_frac=0.5,\n        max_crop_frac=1.0,\n        rec_prob=0.0,\n        repeats=100,\n        interpolation=\"bicubic\",\n        flip_p=0.,\n        set=\"train\",\n        placeholder_token=\"dog\",\n        per_image_tokens=False,\n        mixing_prob=0.25,\n        coarse_class_text=None):\n        self.data_root = data_root\n        assert os.path.isfile(self.data_root), f\"SinImageHighResDataset requires a path to an image file, not a directory. Got {self.data_root}.\"\n\n        self.num_images = 100\n        self._length = self.num_images\n\n        self.placeholder_token = placeholder_token\n\n        self.per_image_tokens = per_image_tokens\n        self.mixing_prob = mixing_prob\n\n        self.coarse_class_text = coarse_class_text\n\n        if per_image_tokens:\n            assert self.num_images < len(per_img_token_list), f\"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} images. To enable larger sets, add more tokens to 'per_img_token_list'.\"\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.size = size\n        self.high_resolution = high_resolution\n        self.min_crop_frac = min_crop_frac\n        self.max_crop_frac = max_crop_frac\n        self.rec_prob = rec_prob\n        self.latent_scale = latent_scale\n\n        self.interpolation = {\"linear\": PIL.Image.LINEAR,\n                              \"bilinear\": PIL.Image.BILINEAR,\n                              \"bicubic\": PIL.Image.BICUBIC,\n                              \"lanczos\": PIL.Image.LANCZOS,\n                              }[interpolation]\n        self.flip = transforms.RandomHorizontalFlip(p=flip_p)\n\n        image = Image.open(self.data_root)\n\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n\n        self.image = image.resize((self.high_resolution, self.high_resolution), resample=self.interpolation)\n\n    def __len__(self):\n        return self._length\n\n    def _random_crop(self, pil_image):\n        patch_size_y = int(\n            (self.high_resolution//self.latent_scale) * (random.random() * (self.max_crop_frac - self.min_crop_frac) + self.min_crop_frac))\n        patch_size_x = int(\n            (self.high_resolution//self.latent_scale) * (random.random() * (self.max_crop_frac - self.min_crop_frac) + self.min_crop_frac))\n        crop_y = random.randrange((self.high_resolution//self.latent_scale) - patch_size_y + 1)\n        crop_x = random.randrange((self.high_resolution//self.latent_scale) - patch_size_x + 1)\n        return pil_image.crop((crop_x * self.latent_scale, crop_y * self.latent_scale, (crop_x + patch_size_x) * self.latent_scale, (crop_y + patch_size_y) * self.latent_scale)).resize(\n            (self.size, self.size),\n            resample=self.interpolation), crop_y, crop_x, crop_y + patch_size_y, crop_x + patch_size_x\n\n    def __getitem__(self, i):\n        example = {}\n        image = deepcopy(self.image)\n\n        placeholder_string = self.placeholder_token\n        if self.coarse_class_text:\n            placeholder_string = f\"{self.coarse_class_text} {placeholder_string}\"\n\n        text = random.choice(training_templates_smallest).format(placeholder_string)\n\n        example[\"caption\"] = text\n\n        if random.random() < self.rec_prob:\n            image, crop_y, crop_x, crop_y1, crop_x1 = self._random_crop(image)\n            crop_area = torch.tensor([crop_y, crop_x, crop_y1, crop_x1])\n        else:\n            image = image.resize((self.size, self.size), resample=self.interpolation)\n            crop_area = torch.tensor([0, 0, self.high_resolution//self.latent_scale, self.high_resolution//self.latent_scale])\n\n        image = self.flip(image)\n        image = np.array(image).astype(np.uint8)\n        example[\"image\"] = (image / 127.5 - 1.0).astype(np.float32)\n        example['crop_boxes'] = crop_area\n        return example\n"
  },
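  {
    "path": "examples/patch_crop_box_sketch.py",
    "content": "# NOTE: hypothetical example sketch; this file is not part of the original SINE sources.\n# It reproduces, standalone, the crop-box arithmetic of SinImageHighResDataset._random_crop\n# in ldm/data/personalized.py: boxes are sampled on the latent grid\n# (high_resolution // latent_scale), so the pixel-space crop always covers whole VAE\n# latent cells, and the latent-space box is what ends up in example['crop_boxes'].\nimport random\n\n\ndef sample_crop_box(high_resolution=1024, latent_scale=8, min_crop_frac=0.5, max_crop_frac=1.0):\n    grid = high_resolution // latent_scale  # side length of the latent grid\n    patch_y = int(grid * random.uniform(min_crop_frac, max_crop_frac))\n    patch_x = int(grid * random.uniform(min_crop_frac, max_crop_frac))\n    crop_y = random.randrange(grid - patch_y + 1)\n    crop_x = random.randrange(grid - patch_x + 1)\n    # pixel-space box (left, upper, right, lower), as passed to PIL's Image.crop\n    pixel_box = (crop_x * latent_scale, crop_y * latent_scale,\n                 (crop_x + patch_x) * latent_scale, (crop_y + patch_y) * latent_scale)\n    # latent-space box (y0, x0, y1, x1), matching the crop_area tensor in the dataset\n    latent_box = (crop_y, crop_x, crop_y + patch_y, crop_x + patch_x)\n    return pixel_box, latent_box\n\n\nif __name__ == '__main__':\n    print(sample_crop_box())\n"
  },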
  {
    "path": "ldm/data/personalized_painting.py",
    "content": "import os\nimport numpy as np\nimport re\nimport PIL\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision import transforms\nfrom copy import deepcopy\n\nimport random\nimport torch\n\ntraining_templates_smallest = [\n    'painting of a sks {}',\n]\nstyle_templates_smallest = [\n    '{} in the sks style',\n]\n\nreg_templates_smallest = [\n    'painting of a {}',\n]\n\n\nreg_templates_smallest = [\n    'photo of a {}',\n]\n\nreg_templates_no_class_smallest = [\n    'a photo',\n]\n\nreg_templates_no_class_small = [\n    'a photo',\n    'a rendering',\n    'a cropped photo',\n    'the photo',\n    'a dark photo',\n    'a close-up photo',\n    'a bright photo',\n    'a cropped photo',\n    'a good photo',\n    'a rendition',\n    'an illustration',\n    'a depiction',\n]\n\n\nper_img_token_list = [\n    'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',\n]\n\n\nclass PersonalizedBase(Dataset):\n    def __init__(self,\n                 data_root,\n                 size=None,\n                 repeats=100,\n                 interpolation=\"bicubic\",\n                 flip_p=0.5,\n                 set=\"train\",\n                 placeholder_token=\"dog\",\n                 per_image_tokens=False,\n                 center_crop=False,\n                 mixing_prob=0.25,\n                 coarse_class_text=None,\n                 reg=False,\n                 learn_style=False,\n                 ):\n\n        self.data_root = data_root\n\n        self.image_paths = [os.path.join(\n            self.data_root, file_path) for file_path in os.listdir(self.data_root)]\n\n        self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        self.placeholder_token = placeholder_token\n\n        self.per_image_tokens = per_image_tokens\n        self.center_crop = center_crop\n        self.mixing_prob = mixing_prob\n\n        self.coarse_class_text = coarse_class_text\n        self.learn_style = learn_style\n\n        if per_image_tokens:\n            assert self.num_images < len(\n                per_img_token_list), f\"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. 
To enable larger sets, add more tokens to 'per_img_token_list'.\"\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.size = size\n        self.interpolation = {\"linear\": PIL.Image.LINEAR,\n                              \"bilinear\": PIL.Image.BILINEAR,\n                              \"bicubic\": PIL.Image.BICUBIC,\n                              \"lanczos\": PIL.Image.LANCZOS,\n                              }[interpolation]\n        self.flip = transforms.RandomHorizontalFlip(p=flip_p)\n        self.reg = reg\n\n    def __len__(self):\n        return self._length\n\n    def __getitem__(self, i):\n        example = {}\n        image = Image.open(self.image_paths[i % self.num_images])\n\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n\n        placeholder_string = self.placeholder_token\n        if self.coarse_class_text:\n            placeholder_string = f\"{self.coarse_class_text} {placeholder_string}\"\n\n        if not self.reg:\n            if not self.learn_style:\n                text = random.choice(training_templates_smallest).format(\n                    placeholder_string)\n            else:\n                text = random.choice(style_templates_smallest).format(\n                    placeholder_string)\n        else:\n            text = random.choice(reg_templates_smallest).format(\n                placeholder_string)\n\n        example[\"caption\"] = text\n\n        # default to score-sde preprocessing\n        img = np.array(image).astype(np.uint8)\n\n        if self.center_crop:\n            crop = min(img.shape[0], img.shape[1])\n            h, w, = img.shape[0], img.shape[1]\n            img = img[(h - crop) // 2:(h + crop) // 2,\n                      (w - crop) // 2:(w + crop) // 2]\n\n        image = Image.fromarray(img)\n        if self.size is not None:\n            image = image.resize((self.size, self.size),\n                                 resample=self.interpolation)\n\n        image = self.flip(image)\n        image = np.array(image).astype(np.uint8)\n        example[\"image\"] = (image / 127.5 - 1.0).astype(np.float32)\n        return example\n\n\nclass SinImageDataset(PersonalizedBase):\n    def __init__(\n        self,\n        data_root,\n        size=None,\n        repeats=100,\n        interpolation=\"bicubic\",\n        flip_p=0.5,\n        set=\"train\",\n        placeholder_token=\"dog\",\n        per_image_tokens=False,\n        center_crop=False,\n        mixing_prob=0.25,\n        coarse_class_text=None,\n        reg=False,\n        learn_style=False,\n    ):\n        self.data_root = data_root\n        assert os.path.isfile(\n            self.data_root), f\"SinImageDataset requires a path to a image file, not a directory. Got {self.data_root}.\"\n\n        self.image_paths = [self.data_root]*100\n\n        self.num_images = len(self.image_paths)\n        self._length = self.num_images\n\n        self.placeholder_token = placeholder_token\n\n        self.per_image_tokens = per_image_tokens\n        self.center_crop = center_crop\n        self.mixing_prob = mixing_prob\n\n        self.coarse_class_text = coarse_class_text\n        self.learn_style = learn_style\n\n        if per_image_tokens:\n            assert self.num_images < len(\n                per_img_token_list), f\"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. 
To enable larger sets, add more tokens to 'per_img_token_list'.\"\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.size = size\n        self.interpolation = {\"linear\": PIL.Image.LINEAR,\n                              \"bilinear\": PIL.Image.BILINEAR,\n                              \"bicubic\": PIL.Image.BICUBIC,\n                              \"lanczos\": PIL.Image.LANCZOS,\n                              }[interpolation]\n        self.flip = transforms.RandomHorizontalFlip(p=flip_p)\n        self.reg = reg\n\n\n\nclass SinImageHighResDataset(Dataset):\n    def __init__(self,\n                 data_root,\n                 size=512,\n                 high_resolution=1024,\n                 latent_scale=8,\n                 min_crop_frac=0.5,\n                 max_crop_frac=1.0,\n                 rec_prob=0.0,\n                 repeats=100,\n                 interpolation=\"bicubic\",\n                 flip_p=0.,\n                 set=\"train\",\n                 placeholder_token=\"dog\",\n                 per_image_tokens=False,\n\n                 mixing_prob=0.25,\n                 coarse_class_text=None):\n        self.data_root = data_root\n        assert os.path.isfile(\n            self.data_root), f\"SinImageDataset requires a path to a image file, not a directory. Got {self.data_root}.\"\n\n        self.num_images = 100\n        self._length = self.num_images\n\n        self.placeholder_token = placeholder_token\n\n        self.per_image_tokens = per_image_tokens\n        self.mixing_prob = mixing_prob\n\n        self.coarse_class_text = coarse_class_text\n\n        if per_image_tokens:\n            assert self.num_images < len(\n                per_img_token_list), f\"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. 
To enable larger sets, add more tokens to 'per_img_token_list'.\"\n\n        if set == \"train\":\n            self._length = self.num_images * repeats\n\n        self.size = size\n        self.high_resolution = high_resolution\n        self.min_crop_frac = min_crop_frac\n        self.max_crop_frac = max_crop_frac\n        self.rec_prob = rec_prob\n        self.latent_scale = latent_scale\n\n        self.interpolation = {\"linear\": PIL.Image.LINEAR,\n                              \"bilinear\": PIL.Image.BILINEAR,\n                              \"bicubic\": PIL.Image.BICUBIC,\n                              \"lanczos\": PIL.Image.LANCZOS,\n                              }[interpolation]\n        self.flip = transforms.RandomHorizontalFlip(p=flip_p)\n\n        image = Image.open(self.data_root)\n\n        if not image.mode == \"RGB\":\n            image = image.convert(\"RGB\")\n\n        self.image = image.resize(\n            (self.high_resolution, self.high_resolution), resample=self.interpolation)\n\n    def __len__(self):\n        return self._length\n\n    def _random_crop(self, pil_image):\n        patch_size_y = int(\n            (self.high_resolution//self.latent_scale) * (random.random() * (self.max_crop_frac - self.min_crop_frac) + self.min_crop_frac))\n        patch_size_x = int(\n            (self.high_resolution//self.latent_scale) * (random.random() * (self.max_crop_frac - self.min_crop_frac) + self.min_crop_frac))\n        crop_y = random.randrange(\n            (self.high_resolution//self.latent_scale) - patch_size_y + 1)\n        crop_x = random.randrange(\n            (self.high_resolution//self.latent_scale) - patch_size_x + 1)\n        return pil_image.crop((crop_x * self.latent_scale, crop_y * self.latent_scale, (crop_x + patch_size_x) * self.latent_scale, (crop_y + patch_size_y) * self.latent_scale)).resize(\n            (self.size, self.size),\n            resample=self.interpolation), crop_y, crop_x, crop_y + patch_size_y, crop_x + patch_size_x\n\n    def __getitem__(self, i):\n        example = {}\n        image = deepcopy(self.image)\n\n        placeholder_string = self.placeholder_token\n        if self.coarse_class_text:\n            placeholder_string = f\"{self.coarse_class_text} {placeholder_string}\"\n\n        text = random.choice(training_templates_smallest).format(\n            placeholder_string)\n\n        example[\"caption\"] = text\n\n        if random.random() < self.rec_prob:\n            image, crop_y, crop_x, crop_y1, crop_x1 = self._random_crop(image)\n            crop_area = torch.tensor([crop_y, crop_x, crop_y1, crop_x1])\n        else:\n            image = image.resize((self.size, self.size),\n                                 resample=self.interpolation)\n            crop_area = torch.tensor(\n                [0, 0, self.high_resolution//self.latent_scale, self.high_resolution//self.latent_scale])\n\n        image = self.flip(image)\n        image = np.array(image).astype(np.uint8)\n        example[\"image\"] = (image / 127.5 - 1.0).astype(np.float32)\n        example['crop_boxes'] = crop_area\n        return example\n"
  },
  {
    "path": "ldm/lr_scheduler.py",
    "content": "import numpy as np\n\n\nclass LambdaWarmUpCosineScheduler:\n    \"\"\"\n    note: use with a base_lr of 1.0\n    \"\"\"\n    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):\n        self.lr_warm_up_steps = warm_up_steps\n        self.lr_start = lr_start\n        self.lr_min = lr_min\n        self.lr_max = lr_max\n        self.lr_max_decay_steps = max_decay_steps\n        self.last_lr = 0.\n        self.verbosity_interval = verbosity_interval\n\n    def schedule(self, n, **kwargs):\n        if self.verbosity_interval > 0:\n            if n % self.verbosity_interval == 0: print(f\"current step: {n}, recent lr-multiplier: {self.last_lr}\")\n        if n < self.lr_warm_up_steps:\n            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start\n            self.last_lr = lr\n            return lr\n        else:\n            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)\n            t = min(t, 1.0)\n            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (\n                    1 + np.cos(t * np.pi))\n            self.last_lr = lr\n            return lr\n\n    def __call__(self, n, **kwargs):\n        return self.schedule(n,**kwargs)\n\n\nclass LambdaWarmUpCosineScheduler2:\n    \"\"\"\n    supports repeated iterations, configurable via lists\n    note: use with a base_lr of 1.0.\n    \"\"\"\n    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):\n        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)\n        self.lr_warm_up_steps = warm_up_steps\n        self.f_start = f_start\n        self.f_min = f_min\n        self.f_max = f_max\n        self.cycle_lengths = cycle_lengths\n        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))\n        self.last_f = 0.\n        self.verbosity_interval = verbosity_interval\n\n    def find_in_interval(self, n):\n        interval = 0\n        for cl in self.cum_cycles[1:]:\n            if n <= cl:\n                return interval\n            interval += 1\n\n    def schedule(self, n, **kwargs):\n        cycle = self.find_in_interval(n)\n        n = n - self.cum_cycles[cycle]\n        if self.verbosity_interval > 0:\n            if n % self.verbosity_interval == 0: print(f\"current step: {n}, recent lr-multiplier: {self.last_f}, \"\n                                                       f\"current cycle {cycle}\")\n        if n < self.lr_warm_up_steps[cycle]:\n            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]\n            self.last_f = f\n            return f\n        else:\n            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])\n            t = min(t, 1.0)\n            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (\n                    1 + np.cos(t * np.pi))\n            self.last_f = f\n            return f\n\n    def __call__(self, n, **kwargs):\n        return self.schedule(n, **kwargs)\n\n\nclass LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):\n\n    def schedule(self, n, **kwargs):\n        cycle = self.find_in_interval(n)\n        n = n - self.cum_cycles[cycle]\n        if self.verbosity_interval > 0:\n            if n % self.verbosity_interval == 0: print(f\"current step: {n}, recent lr-multiplier: {self.last_f}, \"\n                                             
          f\"current cycle {cycle}\")\n\n        if n < self.lr_warm_up_steps[cycle]:\n            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]\n            self.last_f = f\n            return f\n        else:\n            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])\n            self.last_f = f\n            return f\n\n"
  },
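  {
    "path": "examples/lr_scheduler_sketch.py",
    "content": "# NOTE: hypothetical usage sketch; this file is not part of the original SINE sources.\n# It wires LambdaLinearScheduler from ldm/lr_scheduler.py into torch's LambdaLR. As the\n# class docstrings note, these schedulers return an lr *multiplier* and are meant to be\n# used with a base_lr of 1.0, so the target learning rate is encoded in f_max below.\nimport torch\n\nfrom ldm.lr_scheduler import LambdaLinearScheduler\n\nschedule = LambdaLinearScheduler(\n    warm_up_steps=[100], f_min=[1.0e-6], f_max=[1.0e-4], f_start=[1.0e-6], cycle_lengths=[10000])\n\nmodel = torch.nn.Linear(4, 4)\noptimizer = torch.optim.AdamW(model.parameters(), lr=1.0)  # base_lr of 1.0\nlr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)\n\nfor step in range(5):\n    optimizer.step()      # linear warm-up for 100 steps, then linear decay\n    lr_scheduler.step()\n    print(step, lr_scheduler.get_last_lr()[0])\n"
  },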
  {
    "path": "ldm/modules/attention.py",
    "content": "from inspect import isfunction\nimport math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum\nfrom einops import rearrange, repeat\n\nfrom ldm.modules.diffusionmodules.util import checkpoint\n\n\ndef exists(val):\n    return val is not None\n\n\ndef uniq(arr):\n    return{el: True for el in arr}.keys()\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef max_neg_value(t):\n    return -torch.finfo(t.dtype).max\n\n\ndef init_(tensor):\n    dim = tensor.shape[-1]\n    std = 1 / math.sqrt(dim)\n    tensor.uniform_(-std, std)\n    return tensor\n\n\n# feedforward\nclass GEGLU(nn.Module):\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out * 2)\n\n    def forward(self, x):\n        x, gate = self.proj(x).chunk(2, dim=-1)\n        return x * F.gelu(gate)\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):\n        super().__init__()\n        inner_dim = int(dim * mult)\n        dim_out = default(dim_out, dim)\n        project_in = nn.Sequential(\n            nn.Linear(dim, inner_dim),\n            nn.GELU()\n        ) if not glu else GEGLU(dim, inner_dim)\n\n        self.net = nn.Sequential(\n            project_in,\n            nn.Dropout(dropout),\n            nn.Linear(inner_dim, dim_out)\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef Normalize(in_channels):\n    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)\n\n\nclass LinearAttention(nn.Module):\n    def __init__(self, dim, heads=4, dim_head=32):\n        super().__init__()\n        self.heads = heads\n        hidden_dim = dim_head * heads\n        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)\n        self.to_out = nn.Conv2d(hidden_dim, dim, 1)\n\n    def forward(self, x):\n        b, c, h, w = x.shape\n        qkv = self.to_qkv(x)\n        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)\n        k = k.softmax(dim=-1)  \n        context = torch.einsum('bhdn,bhen->bhde', k, v)\n        out = torch.einsum('bhde,bhdn->bhen', context, q)\n        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)\n        return self.to_out(out)\n\n\nclass SpatialSelfAttention(nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = Normalize(in_channels)\n        self.q = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.k = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.v = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.proj_out = torch.nn.Conv2d(in_channels,\n  
                                      in_channels,\n                                        kernel_size=1,\n                                        stride=1,\n                                        padding=0)\n\n    def forward(self, x):\n        h_ = x\n        h_ = self.norm(h_)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        # compute attention\n        b,c,h,w = q.shape\n        q = rearrange(q, 'b c h w -> b (h w) c')\n        k = rearrange(k, 'b c h w -> b c (h w)')\n        w_ = torch.einsum('bij,bjk->bik', q, k)\n\n        w_ = w_ * (int(c)**(-0.5))\n        w_ = torch.nn.functional.softmax(w_, dim=2)\n\n        # attend to values\n        v = rearrange(v, 'b c h w -> b c (h w)')\n        w_ = rearrange(w_, 'b i j -> b j i')\n        h_ = torch.einsum('bij,bjk->bik', v, w_)\n        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)\n        h_ = self.proj_out(h_)\n\n        return x+h_\n\n\nclass CrossAttention(nn.Module):\n    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):\n        super().__init__()\n        inner_dim = dim_head * heads\n        context_dim = default(context_dim, query_dim)\n\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n\n        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)\n        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)\n        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)\n\n        self.to_out = nn.Sequential(\n            nn.Linear(inner_dim, query_dim),\n            nn.Dropout(dropout)\n        )\n\n    def forward(self, x, context=None, mask=None):\n        h = self.heads\n\n        q = self.to_q(x)\n        context = default(context, x)\n        k = self.to_k(context)\n        v = self.to_v(context)\n\n        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))\n\n        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale\n\n        if exists(mask):\n            mask = rearrange(mask, 'b ... 
-> b (...)')\n            max_neg_value = -torch.finfo(sim.dtype).max\n            mask = repeat(mask, 'b j -> (b h) () j', h=h)\n            sim.masked_fill_(~mask, max_neg_value)\n\n        # attention, what we cannot get enough of\n        attn = sim.softmax(dim=-1)\n\n        out = einsum('b i j, b j d -> b i d', attn, v)\n        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)\n        return self.to_out(out)\n\n\nclass BasicTransformerBlock(nn.Module):\n    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):\n        super().__init__()\n        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention\n        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)\n        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,\n                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n        self.norm3 = nn.LayerNorm(dim)\n        self.checkpoint = checkpoint\n\n    def forward(self, x, context=None):\n        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)\n\n    def _forward(self, x, context=None):\n        x = self.attn1(self.norm1(x)) + x\n        x = self.attn2(self.norm2(x), context=context) + x\n        x = self.ff(self.norm3(x)) + x\n        return x\n\n\nclass SpatialTransformer(nn.Module):\n    \"\"\"\n    Transformer block for image-like data.\n    First, project the input (aka embedding)\n    and reshape to b, t, d.\n    Then apply standard transformer action.\n    Finally, reshape to image\n    \"\"\"\n    def __init__(self, in_channels, n_heads, d_head,\n                 depth=1, dropout=0., context_dim=None):\n        super().__init__()\n        self.in_channels = in_channels\n        inner_dim = n_heads * d_head\n        self.norm = Normalize(in_channels)\n\n        self.proj_in = nn.Conv2d(in_channels,\n                                 inner_dim,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n\n        self.transformer_blocks = nn.ModuleList(\n            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)\n                for d in range(depth)]\n        )\n\n        self.proj_out = zero_module(nn.Conv2d(inner_dim,\n                                              in_channels,\n                                              kernel_size=1,\n                                              stride=1,\n                                              padding=0))\n\n    def forward(self, x, context=None):\n        # note: if no context is given, cross-attention defaults to self-attention\n        b, c, h, w = x.shape\n        x_in = x\n        x = self.norm(x)\n        x = self.proj_in(x)\n        x = rearrange(x, 'b c h w -> b (h w) c')\n        for block in self.transformer_blocks:\n            x = block(x, context=context)\n        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)\n        x = self.proj_out(x)\n        return x + x_in"
  },
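  {
    "path": "examples/cross_attention_sketch.py",
    "content": "# NOTE: hypothetical example sketch; this file is not part of the original SINE sources.\n# It exercises the CrossAttention block from ldm/modules/attention.py with dummy tensors.\n# The dimensions mimic a Stable-Diffusion-like setup (320-channel features attending over\n# 77 text tokens of width 768) but are otherwise arbitrary assumptions.\nimport torch\n\nfrom ldm.modules.attention import CrossAttention\n\nattn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)\nx = torch.randn(2, 4096, 320)      # batch, flattened h*w image tokens, channels\ncontext = torch.randn(2, 77, 768)  # batch, text tokens, embedding dim\n\nout = attn(x, context=context)     # cross-attention over the text context\nprint(out.shape)                   # torch.Size([2, 4096, 320])\n"
  },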
  {
    "path": "ldm/modules/diffusionmodules/__init__.py",
    "content": ""
  },
  {
    "path": "ldm/modules/diffusionmodules/model.py",
    "content": "# pytorch_diffusion + derived encoder decoder\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom einops import rearrange\n\nfrom ldm.util import instantiate_from_config\nfrom ldm.modules.attention import LinearAttention\n\n\ndef get_timestep_embedding(timesteps, embedding_dim):\n    \"\"\"\n    This matches the implementation in Denoising Diffusion Probabilistic Models:\n    From Fairseq.\n    Build sinusoidal embeddings.\n    This matches the implementation in tensor2tensor, but differs slightly\n    from the description in Section 3.5 of \"Attention Is All You Need\".\n    \"\"\"\n    assert len(timesteps.shape) == 1\n\n    half_dim = embedding_dim // 2\n    emb = math.log(10000) / (half_dim - 1)\n    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)\n    emb = emb.to(device=timesteps.device)\n    emb = timesteps.float()[:, None] * emb[None, :]\n    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n    if embedding_dim % 2 == 1:  # zero pad\n        emb = torch.nn.functional.pad(emb, (0,1,0,0))\n    return emb\n\n\ndef nonlinearity(x):\n    # swish\n    return x*torch.sigmoid(x)\n\n\ndef Normalize(in_channels, num_groups=32):\n    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)\n\n\nclass Upsample(nn.Module):\n    def __init__(self, in_channels, with_conv):\n        super().__init__()\n        self.with_conv = with_conv\n        if self.with_conv:\n            self.conv = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, x):\n        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode=\"nearest\")\n        if self.with_conv:\n            x = self.conv(x)\n        return x\n\n\nclass Downsample(nn.Module):\n    def __init__(self, in_channels, with_conv):\n        super().__init__()\n        self.with_conv = with_conv\n        if self.with_conv:\n            # no asymmetric padding in torch conv, must do it ourselves\n            self.conv = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=3,\n                                        stride=2,\n                                        padding=0)\n\n    def forward(self, x):\n        if self.with_conv:\n            pad = (0,1,0,1)\n            x = torch.nn.functional.pad(x, pad, mode=\"constant\", value=0)\n            x = self.conv(x)\n        else:\n            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)\n        return x\n\n\nclass ResnetBlock(nn.Module):\n    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,\n                 dropout, temb_channels=512):\n        super().__init__()\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n        self.use_conv_shortcut = conv_shortcut\n\n        self.norm1 = Normalize(in_channels)\n        self.conv1 = torch.nn.Conv2d(in_channels,\n                                     out_channels,\n                                     kernel_size=3,\n                                     stride=1,\n                                     padding=1)\n        if temb_channels > 0:\n            self.temb_proj = torch.nn.Linear(temb_channels,\n  
                                           out_channels)\n        self.norm2 = Normalize(out_channels)\n        self.dropout = torch.nn.Dropout(dropout)\n        self.conv2 = torch.nn.Conv2d(out_channels,\n                                     out_channels,\n                                     kernel_size=3,\n                                     stride=1,\n                                     padding=1)\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                self.conv_shortcut = torch.nn.Conv2d(in_channels,\n                                                     out_channels,\n                                                     kernel_size=3,\n                                                     stride=1,\n                                                     padding=1)\n            else:\n                self.nin_shortcut = torch.nn.Conv2d(in_channels,\n                                                    out_channels,\n                                                    kernel_size=1,\n                                                    stride=1,\n                                                    padding=0)\n\n    def forward(self, x, temb):\n        h = x\n        h = self.norm1(h)\n        h = nonlinearity(h)\n        h = self.conv1(h)\n\n        if temb is not None:\n            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]\n\n        h = self.norm2(h)\n        h = nonlinearity(h)\n        h = self.dropout(h)\n        h = self.conv2(h)\n\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                x = self.conv_shortcut(x)\n            else:\n                x = self.nin_shortcut(x)\n\n        return x+h\n\n\nclass LinAttnBlock(LinearAttention):\n    \"\"\"to match AttnBlock usage\"\"\"\n    def __init__(self, in_channels):\n        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)\n\n\nclass AttnBlock(nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = Normalize(in_channels)\n        self.q = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.k = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.v = torch.nn.Conv2d(in_channels,\n                                 in_channels,\n                                 kernel_size=1,\n                                 stride=1,\n                                 padding=0)\n        self.proj_out = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=1,\n                                        stride=1,\n                                        padding=0)\n\n\n    def forward(self, x):\n        h_ = x\n        h_ = self.norm(h_)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        # compute attention\n        b,c,h,w = q.shape\n        q = q.reshape(b,c,h*w)\n        q = q.permute(0,2,1)   # b,hw,c\n        k = k.reshape(b,c,h*w) # b,c,hw\n        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]\n        w_ = w_ * (int(c)**(-0.5))\n        w_ = 
torch.nn.functional.softmax(w_, dim=2)\n\n        # attend to values\n        v = v.reshape(b,c,h*w)\n        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)\n        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]\n        h_ = h_.reshape(b,c,h,w)\n\n        h_ = self.proj_out(h_)\n\n        return x+h_\n\n\ndef make_attn(in_channels, attn_type=\"vanilla\"):\n    assert attn_type in [\"vanilla\", \"linear\", \"none\"], f'attn_type {attn_type} unknown'\n    print(f\"making attention of type '{attn_type}' with {in_channels} in_channels\")\n    if attn_type == \"vanilla\":\n        return AttnBlock(in_channels)\n    elif attn_type == \"none\":\n        return nn.Identity(in_channels)\n    else:\n        return LinAttnBlock(in_channels)\n\n\nclass Model(nn.Module):\n    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,\n                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,\n                 resolution, use_timestep=True, use_linear_attn=False, attn_type=\"vanilla\"):\n        super().__init__()\n        if use_linear_attn: attn_type = \"linear\"\n        self.ch = ch\n        self.temb_ch = self.ch*4\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n\n        self.use_timestep = use_timestep\n        if self.use_timestep:\n            # timestep embedding\n            self.temb = nn.Module()\n            self.temb.dense = nn.ModuleList([\n                torch.nn.Linear(self.ch,\n                                self.temb_ch),\n                torch.nn.Linear(self.temb_ch,\n                                self.temb_ch),\n            ])\n\n        # downsampling\n        self.conv_in = torch.nn.Conv2d(in_channels,\n                                       self.ch,\n                                       kernel_size=3,\n                                       stride=1,\n                                       padding=1)\n\n        curr_res = resolution\n        in_ch_mult = (1,)+tuple(ch_mult)\n        self.down = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_in = ch*in_ch_mult[i_level]\n            block_out = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks):\n                block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            down = nn.Module()\n            down.block = block\n            down.attn = attn\n            if i_level != self.num_resolutions-1:\n                down.downsample = Downsample(block_in, resamp_with_conv)\n                curr_res = curr_res // 2\n            self.down.append(down)\n\n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)\n        
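# second ResnetBlock closes the block -> attention -> block sandwich at the\n        # bottleneck (curr_res == resolution // 2**(num_resolutions-1) here)\n        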
self.mid.block_2 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n\n        # upsampling\n        self.up = nn.ModuleList()\n        for i_level in reversed(range(self.num_resolutions)):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_out = ch*ch_mult[i_level]\n            skip_in = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks+1):\n                if i_block == self.num_res_blocks:\n                    skip_in = ch*in_ch_mult[i_level]\n                block.append(ResnetBlock(in_channels=block_in+skip_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            up = nn.Module()\n            up.block = block\n            up.attn = attn\n            if i_level != 0:\n                up.upsample = Upsample(block_in, resamp_with_conv)\n                curr_res = curr_res * 2\n            self.up.insert(0, up) # prepend to get consistent order\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n                                        out_ch,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, x, t=None, context=None):\n        #assert x.shape[2] == x.shape[3] == self.resolution\n        if context is not None:\n            # assume aligned context, cat along channel axis\n            x = torch.cat((x, context), dim=1)\n        if self.use_timestep:\n            # timestep embedding\n            assert t is not None\n            temb = get_timestep_embedding(t, self.ch)\n            temb = self.temb.dense[0](temb)\n            temb = nonlinearity(temb)\n            temb = self.temb.dense[1](temb)\n        else:\n            temb = None\n\n        # downsampling\n        hs = [self.conv_in(x)]\n        for i_level in range(self.num_resolutions):\n            for i_block in range(self.num_res_blocks):\n                h = self.down[i_level].block[i_block](hs[-1], temb)\n                if len(self.down[i_level].attn) > 0:\n                    h = self.down[i_level].attn[i_block](h)\n                hs.append(h)\n            if i_level != self.num_resolutions-1:\n                hs.append(self.down[i_level].downsample(hs[-1]))\n\n        # middle\n        h = hs[-1]\n        h = self.mid.block_1(h, temb)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h, temb)\n\n        # upsampling\n        for i_level in reversed(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks+1):\n                h = self.up[i_level].block[i_block](\n                    torch.cat([h, hs.pop()], dim=1), temb)\n                if len(self.up[i_level].attn) > 0:\n                    h = self.up[i_level].attn[i_block](h)\n            if i_level != 0:\n                h = self.up[i_level].upsample(h)\n\n        # end\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n    def 
get_last_layer(self):\n        return self.conv_out.weight\n\n\nclass Encoder(nn.Module):\n    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,\n                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,\n                 resolution, z_channels, double_z=True, use_linear_attn=False, attn_type=\"vanilla\",\n                 **ignore_kwargs):\n        super().__init__()\n        if use_linear_attn: attn_type = \"linear\"\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n\n        # downsampling\n        self.conv_in = torch.nn.Conv2d(in_channels,\n                                       self.ch,\n                                       kernel_size=3,\n                                       stride=1,\n                                       padding=1)\n\n        curr_res = resolution\n        in_ch_mult = (1,)+tuple(ch_mult)\n        self.in_ch_mult = in_ch_mult\n        self.down = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_in = ch*in_ch_mult[i_level]\n            block_out = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks):\n                block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            down = nn.Module()\n            down.block = block\n            down.attn = attn\n            if i_level != self.num_resolutions-1:\n                down.downsample = Downsample(block_in, resamp_with_conv)\n                curr_res = curr_res // 2\n            self.down.append(down)\n\n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)\n        self.mid.block_2 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n                                        2*z_channels if double_z else z_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, x):\n        # timestep embedding\n        temb = None\n\n        # downsampling\n        hs = [self.conv_in(x)]\n        for i_level in range(self.num_resolutions):\n            for i_block in range(self.num_res_blocks):\n                h = self.down[i_level].block[i_block](hs[-1], temb)\n                if len(self.down[i_level].attn) > 0:\n                    h = self.down[i_level].attn[i_block](h)\n                hs.append(h)\n     
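       # halve the spatial resolution between levels (no downsample after the last)\n     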
       if i_level != self.num_resolutions-1:\n                hs.append(self.down[i_level].downsample(hs[-1]))\n\n        # middle\n        h = hs[-1]\n        h = self.mid.block_1(h, temb)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h, temb)\n\n        # end\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n\nclass Decoder(nn.Module):\n    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,\n                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,\n                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,\n                 attn_type=\"vanilla\", **ignorekwargs):\n        super().__init__()\n        if use_linear_attn: attn_type = \"linear\"\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n        self.give_pre_end = give_pre_end\n        self.tanh_out = tanh_out\n\n        # compute in_ch_mult, block_in and curr_res at lowest res\n        in_ch_mult = (1,)+tuple(ch_mult)\n        block_in = ch*ch_mult[self.num_resolutions-1]\n        curr_res = resolution // 2**(self.num_resolutions-1)\n        self.z_shape = (1,z_channels,curr_res,curr_res)\n        print(\"Working with z of shape {} = {} dimensions.\".format(\n            self.z_shape, np.prod(self.z_shape)))\n\n        # z to block_in\n        self.conv_in = torch.nn.Conv2d(z_channels,\n                                       block_in,\n                                       kernel_size=3,\n                                       stride=1,\n                                       padding=1)\n\n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)\n        self.mid.block_2 = ResnetBlock(in_channels=block_in,\n                                       out_channels=block_in,\n                                       temb_channels=self.temb_ch,\n                                       dropout=dropout)\n\n        # upsampling\n        self.up = nn.ModuleList()\n        for i_level in reversed(range(self.num_resolutions)):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_out = ch*ch_mult[i_level]\n            for i_block in range(self.num_res_blocks+1):\n                block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            up = nn.Module()\n            up.block = block\n            up.attn = attn\n            if i_level != 0:\n                up.upsample = Upsample(block_in, resamp_with_conv)\n                curr_res = curr_res * 2\n            self.up.insert(0, up) # prepend to get consistent order\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n     
                                   out_ch,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, z):\n        #assert z.shape[1:] == self.z_shape[1:]\n        self.last_z_shape = z.shape\n\n        # timestep embedding\n        temb = None\n\n        # z to block_in\n        h = self.conv_in(z)\n\n        # middle\n        h = self.mid.block_1(h, temb)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h, temb)\n\n        # upsampling\n        for i_level in reversed(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks+1):\n                h = self.up[i_level].block[i_block](h, temb)\n                if len(self.up[i_level].attn) > 0:\n                    h = self.up[i_level].attn[i_block](h)\n            if i_level != 0:\n                h = self.up[i_level].upsample(h)\n\n        # end\n        if self.give_pre_end:\n            return h\n\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        if self.tanh_out:\n            h = torch.tanh(h)\n        return h\n\n\nclass SimpleDecoder(nn.Module):\n    def __init__(self, in_channels, out_channels, *args, **kwargs):\n        super().__init__()\n        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),\n                                     ResnetBlock(in_channels=in_channels,\n                                                 out_channels=2 * in_channels,\n                                                 temb_channels=0, dropout=0.0),\n                                     ResnetBlock(in_channels=2 * in_channels,\n                                                out_channels=4 * in_channels,\n                                                temb_channels=0, dropout=0.0),\n                                     ResnetBlock(in_channels=4 * in_channels,\n                                                out_channels=2 * in_channels,\n                                                temb_channels=0, dropout=0.0),\n                                     nn.Conv2d(2*in_channels, in_channels, 1),\n                                     Upsample(in_channels, with_conv=True)])\n        # end\n        self.norm_out = Normalize(in_channels)\n        self.conv_out = torch.nn.Conv2d(in_channels,\n                                        out_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, x):\n        for i, layer in enumerate(self.model):\n            if i in [1,2,3]:\n                x = layer(x, None)\n            else:\n                x = layer(x)\n\n        h = self.norm_out(x)\n        h = nonlinearity(h)\n        x = self.conv_out(h)\n        return x\n\n\nclass UpsampleDecoder(nn.Module):\n    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,\n                 ch_mult=(2,2), dropout=0.0):\n        super().__init__()\n        # upsampling\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        block_in = in_channels\n        curr_res = resolution // 2 ** (self.num_resolutions - 1)\n        self.res_blocks = nn.ModuleList()\n        self.upsample_blocks = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            res_block = []\n            block_out = ch * 
ch_mult[i_level]\n            for i_block in range(self.num_res_blocks + 1):\n                res_block.append(ResnetBlock(in_channels=block_in,\n                                         out_channels=block_out,\n                                         temb_channels=self.temb_ch,\n                                         dropout=dropout))\n                block_in = block_out\n            self.res_blocks.append(nn.ModuleList(res_block))\n            if i_level != self.num_resolutions - 1:\n                self.upsample_blocks.append(Upsample(block_in, True))\n                curr_res = curr_res * 2\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(block_in,\n                                        out_channels,\n                                        kernel_size=3,\n                                        stride=1,\n                                        padding=1)\n\n    def forward(self, x):\n        # upsampling\n        h = x\n        for k, i_level in enumerate(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks + 1):\n                h = self.res_blocks[i_level][i_block](h, None)\n            if i_level != self.num_resolutions - 1:\n                h = self.upsample_blocks[k](h)\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n\nclass LatentRescaler(nn.Module):\n    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):\n        super().__init__()\n        # residual block, interpolate, residual block\n        self.factor = factor\n        self.conv_in = nn.Conv2d(in_channels,\n                                 mid_channels,\n                                 kernel_size=3,\n                                 stride=1,\n                                 padding=1)\n        self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,\n                                                     out_channels=mid_channels,\n                                                     temb_channels=0,\n                                                     dropout=0.0) for _ in range(depth)])\n        self.attn = AttnBlock(mid_channels)\n        self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,\n                                                     out_channels=mid_channels,\n                                                     temb_channels=0,\n                                                     dropout=0.0) for _ in range(depth)])\n\n        self.conv_out = nn.Conv2d(mid_channels,\n                                  out_channels,\n                                  kernel_size=1,\n                                  )\n\n    def forward(self, x):\n        x = self.conv_in(x)\n        for block in self.res_block1:\n            x = block(x, None)\n        x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))\n        x = self.attn(x)\n        for block in self.res_block2:\n            x = block(x, None)\n        x = self.conv_out(x)\n        return x\n\n\nclass MergedRescaleEncoder(nn.Module):\n    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,\n                 attn_resolutions, dropout=0.0, resamp_with_conv=True,\n                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):\n        super().__init__()\n        intermediate_chn = ch * ch_mult[-1]\n        self.encoder = Encoder(in_channels=in_channels, 
num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,\n                               z_channels=intermediate_chn, double_z=False, resolution=resolution,\n                               attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,\n                               out_ch=None)\n        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,\n                                       mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)\n\n    def forward(self, x):\n        x = self.encoder(x)\n        x = self.rescaler(x)\n        return x\n\n\nclass MergedRescaleDecoder(nn.Module):\n    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),\n                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):\n        super().__init__()\n        tmp_chn = z_channels*ch_mult[-1]\n        self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,\n                               resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,\n                               ch_mult=ch_mult, resolution=resolution, ch=ch)\n        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,\n                                       out_channels=tmp_chn, depth=rescale_module_depth)\n\n    def forward(self, x):\n        x = self.rescaler(x)\n        x = self.decoder(x)\n        return x\n\n\nclass Upsampler(nn.Module):\n    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):\n        super().__init__()\n        assert out_size >= in_size\n        num_blocks = int(np.log2(out_size//in_size))+1\n        factor_up = 1.+ (out_size % in_size)\n        print(f\"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}\")\n        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,\n                                       out_channels=in_channels)\n        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,\n                               attn_resolutions=[], in_channels=None, ch=in_channels,\n                               ch_mult=[ch_mult for _ in range(num_blocks)])\n\n    def forward(self, x):\n        x = self.rescaler(x)\n        x = self.decoder(x)\n        return x\n\n\nclass Resize(nn.Module):\n    def __init__(self, in_channels=None, learned=False, mode=\"bilinear\"):\n        super().__init__()\n        self.with_conv = learned\n        self.mode = mode\n        if self.with_conv:\n            print(f\"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode\")\n            raise NotImplementedError()\n            assert in_channels is not None\n            # no asymmetric padding in torch conv, must do it ourselves\n            self.conv = torch.nn.Conv2d(in_channels,\n                                        in_channels,\n                                        kernel_size=4,\n                                        stride=2,\n                                        padding=1)\n\n    def forward(self, x, scale_factor=1.0):\n        if scale_factor==1.0:\n            return x\n        else:\n            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)\n 
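       # a single interpolate call covers both up- and down-scaling, depending on scale_factor\n 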
       return x\n\nclass FirstStagePostProcessor(nn.Module):\n\n    def __init__(self, ch_mult:list, in_channels,\n                 pretrained_model:nn.Module=None,\n                 reshape=False,\n                 n_channels=None,\n                 dropout=0.,\n                 pretrained_config=None):\n        super().__init__()\n        if pretrained_config is None:\n            assert pretrained_model is not None, 'Either \"pretrained_model\" or \"pretrained_config\" must not be None'\n            self.pretrained_model = pretrained_model\n        else:\n            assert pretrained_config is not None, 'Either \"pretrained_model\" or \"pretrained_config\" must not be None'\n            self.instantiate_pretrained(pretrained_config)\n\n        self.do_reshape = reshape\n\n        if n_channels is None:\n            n_channels = self.pretrained_model.encoder.ch\n\n        self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)\n        self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,\n                            stride=1,padding=1)\n\n        blocks = []\n        downs = []\n        ch_in = n_channels\n        for m in ch_mult:\n            blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))\n            ch_in = m * n_channels\n            downs.append(Downsample(ch_in, with_conv=False))\n\n        self.model = nn.ModuleList(blocks)\n        self.downsampler = nn.ModuleList(downs)\n\n\n    def instantiate_pretrained(self, config):\n        model = instantiate_from_config(config)\n        self.pretrained_model = model.eval()\n        # self.pretrained_model.train = False\n        for param in self.pretrained_model.parameters():\n            param.requires_grad = False\n\n\n    @torch.no_grad()\n    def encode_with_pretrained(self,x):\n        c = self.pretrained_model.encode(x)\n        if isinstance(c, DiagonalGaussianDistribution):\n            c = c.mode()\n        return  c\n\n    def forward(self,x):\n        z_fs = self.encode_with_pretrained(x)\n        z = self.proj_norm(z_fs)\n        z = self.proj(z)\n        z = nonlinearity(z)\n\n        for submodel, downmodel in zip(self.model,self.downsampler):\n            z = submodel(z,temb=None)\n            z = downmodel(z)\n\n        if self.do_reshape:\n            z = rearrange(z,'b c h w -> b (h w) c')\n        return z\n\n"
  },
  {
    "path": "ldm/modules/diffusionmodules/openaimodel.py",
    "content": "from abc import abstractmethod\nfrom decimal import HAVE_THREADS\nfrom email.base64mime import header_length\nfrom functools import partial\nimport math\nfrom typing import Iterable\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom ldm.modules.diffusionmodules.util import (\n    checkpoint,\n    conv_nd,\n    linear,\n    avg_pool_nd,\n    zero_module,\n    normalization,\n    timestep_embedding,\n)\nfrom ldm.modules.attention import SpatialTransformer\nfrom pydantic import StrictInt, StrictFloat, StrictBool, StrictStr\nimport torch\nfrom ldm.modules.diffusionmodules.positional_encoding import SinusoidalPositionalEmbedding\nfrom einops import rearrange, repeat\n\n\n# dummy replace\ndef convert_module_to_f16(x):\n    pass\n\n\ndef convert_module_to_f32(x):\n    pass\n\n\n# go\nclass AttentionPool2d(nn.Module):\n    \"\"\"\n    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py\n    \"\"\"\n\n    def __init__(\n        self,\n        spacial_dim: int,\n        embed_dim: int,\n        num_heads_channels: int,\n        output_dim: int = None,\n    ):\n        super().__init__()\n        self.positional_embedding = nn.Parameter(\n            th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)\n        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)\n        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)\n        self.num_heads = embed_dim // num_heads_channels\n        self.attention = QKVAttention(self.num_heads)\n\n    def forward(self, x):\n        b, c, *_spatial = x.shape\n        x = x.reshape(b, c, -1)  # NC(HW)\n        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)\n        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)\n        x = self.qkv_proj(x)\n        x = self.attention(x)\n        x = self.c_proj(x)\n        return x[:, :, 0]\n\n\nclass TimestepBlock(nn.Module):\n    \"\"\"\n    Any module where forward() takes timestep embeddings as a second argument.\n    \"\"\"\n\n    @abstractmethod\n    def forward(self, x, emb):\n        \"\"\"\n        Apply the module to `x` given `emb` timestep embeddings.\n        \"\"\"\n\n\nclass TimestepEmbedSequential(nn.Sequential, TimestepBlock):\n    \"\"\"\n    A sequential module that passes timestep embeddings to the children that\n    support it as an extra input.\n    \"\"\"\n\n    def forward(self, x, emb, context=None):\n        for layer in self:\n            if isinstance(layer, TimestepBlock):\n                x = layer(x, emb)\n            elif isinstance(layer, SpatialTransformer):\n                x = layer(x, context)\n            else:\n                x = layer(x)\n        return x\n\n\nclass Upsample(nn.Module):\n    \"\"\"\n    An upsampling layer with an optional convolution.\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. 
If 3D, then\n                 upsampling occurs in the inner-two dimensions.\n    \"\"\"\n\n    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.dims = dims\n        if use_conv:\n            self.conv = conv_nd(dims, self.channels,\n                                self.out_channels, 3, padding=padding)\n\n    def forward(self, x):\n        assert x.shape[1] == self.channels\n        if self.dims == 3:\n            x = F.interpolate(\n                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode=\"nearest\"\n            )\n        else:\n            x = F.interpolate(x, scale_factor=2, mode=\"nearest\")\n        if self.use_conv:\n            x = self.conv(x)\n        return x\n\n\nclass TransposedUpsample(nn.Module):\n    'Learned 2x upsampling without padding'\n\n    def __init__(self, channels, out_channels=None, ks=5):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n\n        self.up = nn.ConvTranspose2d(\n            self.channels, self.out_channels, kernel_size=ks, stride=2)\n\n    def forward(self, x):\n        return self.up(x)\n\n\nclass Downsample(nn.Module):\n    \"\"\"\n    A downsampling layer with an optional convolution.\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then\n                 downsampling occurs in the inner-two dimensions.\n    \"\"\"\n\n    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.dims = dims\n        stride = 2 if dims != 3 else (1, 2, 2)\n        if use_conv:\n            self.op = conv_nd(\n                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding\n            )\n        else:\n            assert self.channels == self.out_channels\n            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)\n\n    def forward(self, x):\n        assert x.shape[1] == self.channels\n        return self.op(x)\n\n\nclass ResBlock(TimestepBlock):\n    \"\"\"\n    A residual block that can optionally change the number of channels.\n    :param channels: the number of input channels.\n    :param emb_channels: the number of timestep embedding channels.\n    :param dropout: the rate of dropout.\n    :param out_channels: if specified, the number of out channels.\n    :param use_conv: if True and out_channels is specified, use a spatial\n        convolution instead of a smaller 1x1 convolution to change the\n        channels in the skip connection.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param use_checkpoint: if True, use gradient checkpointing on this module.\n    :param up: if True, use this block for upsampling.\n    :param down: if True, use this block for downsampling.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels,\n        emb_channels,\n        dropout,\n        out_channels=None,\n        use_conv=False,\n        use_scale_shift_norm=False,\n        dims=2,\n        use_checkpoint=False,\n        up=False,\n        down=False,\n    ):\n        super().__init__()\n        
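# cache the block configuration; out_channels defaults to channels when None,\n        # e.g. a hypothetical ResBlock(channels=256, emb_channels=1024, dropout=0.1)\n        # keeps 256 channels and uses an identity skip connection\n        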
self.channels = channels\n        self.emb_channels = emb_channels\n        self.dropout = dropout\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.use_checkpoint = use_checkpoint\n        self.use_scale_shift_norm = use_scale_shift_norm\n\n        self.in_layers = nn.Sequential(\n            normalization(channels),\n            nn.SiLU(),\n            conv_nd(dims, channels, self.out_channels, 3, padding=1),\n        )\n\n        self.updown = up or down\n\n        if up:\n            self.h_upd = Upsample(channels, False, dims)\n            self.x_upd = Upsample(channels, False, dims)\n        elif down:\n            self.h_upd = Downsample(channels, False, dims)\n            self.x_upd = Downsample(channels, False, dims)\n        else:\n            self.h_upd = self.x_upd = nn.Identity()\n\n        self.emb_layers = nn.Sequential(\n            nn.SiLU(),\n            linear(\n                emb_channels,\n                2 * self.out_channels if use_scale_shift_norm else self.out_channels,\n            ),\n        )\n        self.out_layers = nn.Sequential(\n            normalization(self.out_channels),\n            nn.SiLU(),\n            nn.Dropout(p=dropout),\n            zero_module(\n                conv_nd(dims, self.out_channels,\n                        self.out_channels, 3, padding=1)\n            ),\n        )\n\n        if self.out_channels == channels:\n            self.skip_connection = nn.Identity()\n        elif use_conv:\n            self.skip_connection = conv_nd(\n                dims, channels, self.out_channels, 3, padding=1\n            )\n        else:\n            self.skip_connection = conv_nd(\n                dims, channels, self.out_channels, 1)\n\n    def forward(self, x, emb):\n        \"\"\"\n        Apply the block to a Tensor, conditioned on a timestep embedding.\n        :param x: an [N x C x ...] Tensor of features.\n        :param emb: an [N x emb_channels] Tensor of timestep embeddings.\n        :return: an [N x C x ...] 
Tensor of outputs.\n        \"\"\"\n        return checkpoint(\n            self._forward, (x, emb), self.parameters(), self.use_checkpoint\n        )\n\n    def _forward(self, x, emb):\n        if self.updown:\n            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]\n            h = in_rest(x)\n            h = self.h_upd(h)\n            x = self.x_upd(x)\n            h = in_conv(h)\n        else:\n            h = self.in_layers(x)\n        emb_out = self.emb_layers(emb).type(h.dtype)\n        while len(emb_out.shape) < len(h.shape):\n            emb_out = emb_out[..., None]\n        if self.use_scale_shift_norm:\n            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]\n            scale, shift = th.chunk(emb_out, 2, dim=1)\n            h = out_norm(h) * (1 + scale) + shift\n            h = out_rest(h)\n        else:\n            h = h + emb_out\n            h = self.out_layers(h)\n        return self.skip_connection(x) + h\n\n\nclass AttentionBlock(nn.Module):\n    \"\"\"\n    An attention block that allows spatial positions to attend to each other.\n    Originally ported from here, but adapted to the N-d case.\n    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels,\n        num_heads=1,\n        num_head_channels=-1,\n        use_checkpoint=False,\n        use_new_attention_order=False,\n    ):\n        super().__init__()\n        self.channels = channels\n        if num_head_channels == -1:\n            self.num_heads = num_heads\n        else:\n            assert (\n                channels % num_head_channels == 0\n            ), f\"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}\"\n            self.num_heads = channels // num_head_channels\n        self.use_checkpoint = use_checkpoint\n        self.norm = normalization(channels)\n        self.qkv = conv_nd(1, channels, channels * 3, 1)\n        if use_new_attention_order:\n            # split qkv before split heads\n            self.attention = QKVAttention(self.num_heads)\n        else:\n            # split heads before split qkv\n            self.attention = QKVAttentionLegacy(self.num_heads)\n\n        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))\n\n    def forward(self, x):\n        # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!\n        return checkpoint(self._forward, (x,), self.parameters(), True)\n        # return pt_checkpoint(self._forward, x)  # pytorch\n\n    def _forward(self, x):\n        b, c, *spatial = x.shape\n        x = x.reshape(b, c, -1)\n        qkv = self.qkv(self.norm(x))\n        h = self.attention(qkv)\n        h = self.proj_out(h)\n        return (x + h).reshape(b, c, *spatial)\n\n\ndef count_flops_attn(model, _x, y):\n    \"\"\"\n    A counter for the `thop` package to count the operations in an\n    attention operation.\n    Meant to be used like:\n        macs, params = thop.profile(\n            model,\n            inputs=(inputs, timestamps),\n            custom_ops={QKVAttention: QKVAttention.count_flops},\n        )\n    \"\"\"\n    b, c, *spatial = y[0].shape\n    num_spatial = int(np.prod(spatial))\n    # We perform two matmuls with the same number of ops.\n    # The first computes the weight matrix, the second computes\n    # the combination of the value vectors.\n    matmul_ops = 2 * b * (num_spatial ** 2) * c\n    model.total_ops += 
th.DoubleTensor([matmul_ops])\n\n\nclass QKVAttentionLegacy(nn.Module):\n    \"\"\"\n    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping\n    \"\"\"\n\n    def __init__(self, n_heads):\n        super().__init__()\n        self.n_heads = n_heads\n\n    def forward(self, qkv):\n        \"\"\"\n        Apply QKV attention.\n        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.\n        :return: an [N x (H * C) x T] tensor after attention.\n        \"\"\"\n        bs, width, length = qkv.shape\n        assert width % (3 * self.n_heads) == 0\n        ch = width // (3 * self.n_heads)\n        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3,\n                              length).split(ch, dim=1)\n        scale = 1 / math.sqrt(math.sqrt(ch))\n        weight = th.einsum(\n            \"bct,bcs->bts\", q * scale, k * scale\n        )  # More stable with f16 than dividing afterwards\n        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)\n        a = th.einsum(\"bts,bcs->bct\", weight, v)\n        return a.reshape(bs, -1, length)\n\n    @staticmethod\n    def count_flops(model, _x, y):\n        return count_flops_attn(model, _x, y)\n\n\nclass QKVAttention(nn.Module):\n    \"\"\"\n    A module which performs QKV attention and splits in a different order.\n    \"\"\"\n\n    def __init__(self, n_heads):\n        super().__init__()\n        self.n_heads = n_heads\n\n    def forward(self, qkv):\n        \"\"\"\n        Apply QKV attention.\n        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.\n        :return: an [N x (H * C) x T] tensor after attention.\n        \"\"\"\n        bs, width, length = qkv.shape\n        assert width % (3 * self.n_heads) == 0\n        ch = width // (3 * self.n_heads)\n        q, k, v = qkv.chunk(3, dim=1)\n        scale = 1 / math.sqrt(math.sqrt(ch))\n        weight = th.einsum(\n            \"bct,bcs->bts\",\n            (q * scale).view(bs * self.n_heads, ch, length),\n            (k * scale).view(bs * self.n_heads, ch, length),\n        )  # More stable with f16 than dividing afterwards\n        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)\n        a = th.einsum(\"bts,bcs->bct\", weight,\n                      v.reshape(bs * self.n_heads, ch, length))\n        return a.reshape(bs, -1, length)\n\n    @staticmethod\n    def count_flops(model, _x, y):\n        return count_flops_attn(model, _x, y)\n\n\nclass UNetModel(nn.Module):\n    \"\"\"\n    The full UNet model with attention and timestep embedding.\n    :param in_channels: channels in the input Tensor.\n    :param model_channels: base channel count for the model.\n    :param out_channels: channels in the output Tensor.\n    :param num_res_blocks: number of residual blocks per downsample.\n    :param attention_resolutions: a collection of downsample rates at which\n        attention will take place. 
May be a set, list, or tuple.\n        For example, if this contains 4, then at 4x downsampling, attention\n        will be used.\n    :param dropout: the dropout probability.\n    :param channel_mult: channel multiplier for each level of the UNet.\n    :param conv_resample: if True, use learned convolutions for upsampling and\n        downsampling.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param num_classes: if specified (as an int), then this model will be\n        class-conditional with `num_classes` classes.\n    :param use_checkpoint: use gradient checkpointing to reduce memory usage.\n    :param num_heads: the number of attention heads in each attention layer.\n    :param num_head_channels: if specified, ignore num_heads and instead use\n                               a fixed channel width per attention head.\n    :param num_heads_upsample: works with num_heads to set a different number\n                               of heads for upsampling. Deprecated.\n    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.\n    :param resblock_updown: use residual blocks for up/downsampling.\n    :param use_new_attention_order: use a different attention pattern for potentially\n                                    increased efficiency.\n    \"\"\"\n\n    def __init__(\n        self,\n        image_size,\n        in_channels,\n        model_channels,\n        out_channels,\n        num_res_blocks,\n        attention_resolutions,\n        dropout=0,\n        channel_mult=(1, 2, 4, 8),\n        conv_resample=True,\n        dims=2,\n        num_classes=None,\n        use_checkpoint=False,\n        use_fp16=False,\n        num_heads=-1,\n        num_head_channels=-1,\n        num_heads_upsample=-1,\n        use_scale_shift_norm=False,\n        resblock_updown=False,\n        use_new_attention_order=False,\n        use_spatial_transformer=False,    # custom transformer support\n        transformer_depth=1,              # custom transformer support\n        context_dim=None,                 # custom transformer support\n        # custom support for prediction of discrete ids into codebook of first stage vq model\n        n_embed=None,\n        legacy=True,\n    ):\n        super().__init__()\n        if use_spatial_transformer:\n            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'\n\n        if context_dim is not None:\n            assert use_spatial_transformer, 'Fool!! 
You forgot to use the spatial transformer for your cross-attention conditioning...'\n            from omegaconf.listconfig import ListConfig\n            if type(context_dim) == ListConfig:\n                context_dim = list(context_dim)\n\n        if num_heads_upsample == -1:\n            num_heads_upsample = num_heads\n\n        if num_heads == -1:\n            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'\n\n        if num_head_channels == -1:\n            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'\n\n        self.image_size = image_size\n        self.in_channels = in_channels\n        self.model_channels = model_channels\n        self.out_channels = out_channels\n        self.num_res_blocks = num_res_blocks\n        self.attention_resolutions = attention_resolutions\n        self.dropout = dropout\n        self.channel_mult = channel_mult\n        self.conv_resample = conv_resample\n        self.num_classes = num_classes\n        self.use_checkpoint = use_checkpoint\n        self.dtype = th.float16 if use_fp16 else th.float32\n        self.num_heads = num_heads\n        self.num_head_channels = num_head_channels\n        self.num_heads_upsample = num_heads_upsample\n        self.predict_codebook_ids = n_embed is not None\n\n        time_embed_dim = model_channels * 4\n        self.time_embed = nn.Sequential(\n            linear(model_channels, time_embed_dim),\n            nn.SiLU(),\n            linear(time_embed_dim, time_embed_dim),\n        )\n\n        if self.num_classes is not None:\n            self.label_emb = nn.Embedding(num_classes, time_embed_dim)\n\n        self.input_blocks = nn.ModuleList(\n            [\n                TimestepEmbedSequential(\n                    conv_nd(dims, in_channels, model_channels, 3, padding=1)\n                )\n            ]\n        )\n        self._feature_size = model_channels\n        input_block_chans = [model_channels]\n        ch = model_channels\n        ds = 1\n        for level, mult in enumerate(channel_mult):\n            for _ in range(num_res_blocks):\n                layers = [\n                    ResBlock(\n                        ch,\n                        time_embed_dim,\n                        dropout,\n                        out_channels=mult * model_channels,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = mult * model_channels\n                if ds in attention_resolutions:\n                    if num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        dim_head = num_head_channels\n                    if legacy:\n                        #num_heads = 1\n                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels\n                    layers.append(\n                        AttentionBlock(\n                            ch,\n                            use_checkpoint=use_checkpoint,\n                            num_heads=num_heads,\n                            num_head_channels=dim_head,\n                            use_new_attention_order=use_new_attention_order,\n                        ) if not use_spatial_transformer else SpatialTransformer(\n                            ch, num_heads, dim_head, 
depth=transformer_depth, context_dim=context_dim\n                        )\n                    )\n                self.input_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n                input_block_chans.append(ch)\n            if level != len(channel_mult) - 1:\n                out_ch = ch\n                self.input_blocks.append(\n                    TimestepEmbedSequential(\n                        ResBlock(\n                            ch,\n                            time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            down=True,\n                        )\n                        if resblock_updown\n                        else Downsample(\n                            ch, conv_resample, dims=dims, out_channels=out_ch\n                        )\n                    )\n                )\n                ch = out_ch\n                input_block_chans.append(ch)\n                ds *= 2\n                self._feature_size += ch\n\n        if num_head_channels == -1:\n            dim_head = ch // num_heads\n        else:\n            num_heads = ch // num_head_channels\n            dim_head = num_head_channels\n        if legacy:\n            #num_heads = 1\n            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels\n        self.middle_block = TimestepEmbedSequential(\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n            AttentionBlock(\n                ch,\n                use_checkpoint=use_checkpoint,\n                num_heads=num_heads,\n                num_head_channels=dim_head,\n                use_new_attention_order=use_new_attention_order,\n            ) if not use_spatial_transformer else SpatialTransformer(\n                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim\n            ),\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n        )\n        self._feature_size += ch\n\n        self.output_blocks = nn.ModuleList([])\n        for level, mult in list(enumerate(channel_mult))[::-1]:\n            for i in range(num_res_blocks + 1):\n                ich = input_block_chans.pop()\n                layers = [\n                    ResBlock(\n                        ch + ich,\n                        time_embed_dim,\n                        dropout,\n                        out_channels=model_channels * mult,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = model_channels * mult\n                if ds in attention_resolutions:\n                    if num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        
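# with a fixed per-head width, the head count follows from the channel count\n                        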
dim_head = num_head_channels\n                    if legacy:\n                        #num_heads = 1\n                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels\n                    layers.append(\n                        AttentionBlock(\n                            ch,\n                            use_checkpoint=use_checkpoint,\n                            num_heads=num_heads_upsample,\n                            num_head_channels=dim_head,\n                            use_new_attention_order=use_new_attention_order,\n                        ) if not use_spatial_transformer else SpatialTransformer(\n                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim\n                        )\n                    )\n                if level and i == num_res_blocks:\n                    out_ch = ch\n                    layers.append(\n                        ResBlock(\n                            ch,\n                            time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            up=True,\n                        )\n                        if resblock_updown\n                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)\n                    )\n                    ds //= 2\n                self.output_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n\n        self.out = nn.Sequential(\n            normalization(ch),\n            nn.SiLU(),\n            zero_module(conv_nd(dims, model_channels,\n                        out_channels, 3, padding=1)),\n        )\n        if self.predict_codebook_ids:\n            self.id_predictor = nn.Sequential(\n                normalization(ch),\n                conv_nd(dims, model_channels, n_embed, 1),\n                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits\n            )\n\n    def convert_to_fp16(self):\n        \"\"\"\n        Convert the torso of the model to float16.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f16)\n        self.middle_block.apply(convert_module_to_f16)\n        self.output_blocks.apply(convert_module_to_f16)\n\n    def convert_to_fp32(self):\n        \"\"\"\n        Convert the torso of the model to float32.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f32)\n        self.middle_block.apply(convert_module_to_f32)\n        self.output_blocks.apply(convert_module_to_f32)\n\n    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):\n        \"\"\"\n        Apply the model to an input batch.\n        :param x: an [N x C x ...] Tensor of inputs.\n        :param timesteps: a 1-D batch of timesteps.\n        :param context: conditioning plugged in via crossattn\n        :param y: an [N] Tensor of labels, if class-conditional.\n        :return: an [N x C x ...] 
Tensor of outputs.\n        \"\"\"\n        assert (y is not None) == (\n            self.num_classes is not None\n        ), \"must specify y if and only if the model is class-conditional\"\n        hs = []\n        t_emb = timestep_embedding(\n            timesteps, self.model_channels, repeat_only=False)\n        emb = self.time_embed(t_emb)\n\n        if self.num_classes is not None:\n            assert y.shape == (x.shape[0],)\n            emb = emb + self.label_emb(y)\n\n        h = x.type(self.dtype)\n        for module in self.input_blocks:\n            h = module(h, emb, context)\n            hs.append(h)\n        h = self.middle_block(h, emb, context)\n        for module in self.output_blocks:\n            h = th.cat([h, hs.pop()], dim=1)\n            h = module(h, emb, context)\n        h = h.type(x.dtype)\n        if self.predict_codebook_ids:\n            return self.id_predictor(h)\n        else:\n            return self.out(h)\n\n\nclass UNetModelPatch(UNetModel):\n    def __init__(\n        self,\n        image_size,\n        in_channels,\n        model_channels,\n        out_channels,\n        num_res_blocks,\n        attention_resolutions,\n        dropout=0,\n        channel_mult=(1, 2, 4, 8),\n        conv_resample=True,\n        dims=2,\n        num_classes=None,\n        use_checkpoint=False,\n        use_fp16=False,\n        num_heads=-1,\n        num_head_channels=-1,\n        num_heads_upsample=-1,\n        use_scale_shift_norm=False,\n        resblock_updown=False,\n        use_new_attention_order=False,\n        use_spatial_transformer=False,    # custom transformer support\n        transformer_depth=1,              # custom transformer support\n        context_dim=None,                 # custom transformer support\n        # custom support for prediction of discrete ids into codebook of first stage vq model\n        n_embed=None,\n        legacy=True,\n        padding_idx: StrictInt = 0,\n        init_size: StrictInt = 512,\n        div_half_dim: StrictBool = False,\n        center_shift: StrictInt = 200,\n        interpolation_mode: StrictStr = \"bilinear\",\n    ):\n        super().__init__(image_size,\n                         in_channels=in_channels,\n                         model_channels=model_channels,\n                         out_channels=out_channels,\n                         num_res_blocks=num_res_blocks,\n                         attention_resolutions=attention_resolutions,\n                         dropout=dropout,\n                         channel_mult=channel_mult,\n                         conv_resample=conv_resample,\n                         dims=dims,\n                         num_classes=num_classes,\n                         use_checkpoint=use_checkpoint,\n                         use_fp16=use_fp16,\n                         num_heads=num_heads,\n                         num_head_channels=num_head_channels,\n                         num_heads_upsample=num_heads_upsample,\n                         use_scale_shift_norm=use_scale_shift_norm,\n                         resblock_updown=resblock_updown,\n                         use_new_attention_order=use_new_attention_order,\n                         use_spatial_transformer=use_spatial_transformer,    # custom transformer support\n                         transformer_depth=transformer_depth,              # custom transformer support\n                         context_dim=context_dim,                 # custom transformer support\n                         # custom support for prediction of 
discrete ids into codebook of first stage vq model\n                         n_embed=n_embed,\n                         legacy=legacy,)\n        assert model_channels % 2 == 0\n        self.head_position_encode = SinusoidalPositionalEmbedding(embedding_dim=model_channels//2,\n                                                                  padding_idx=padding_idx,\n                                                                  init_size=init_size,\n                                                                  div_half_dim=div_half_dim,\n                                                                  center_shift=center_shift)\n        self.init_size = init_size\n        self.interpolation_mode = interpolation_mode\n\n    def forward(self, x, timesteps=None, context=None, y=None, crop_boxes=None, **kwargs):\n        \"\"\"\n        Apply the model to an input batch.\n        :param x: an [N x C x ...] Tensor of inputs.\n        :param timesteps: a 1-D batch of timesteps.\n        :param context: conditioning plugged in via crossattn\n        :param y: an [N] Tensor of labels, if class-conditional.\n        :return: an [N x C x ...] Tensor of outputs.\n        \"\"\"\n        assert (y is not None) == (\n            self.num_classes is not None\n        ), \"must specify y if and only if the model is class-conditional\"\n        hs = []\n        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)\n        emb = self.time_embed(t_emb)\n        \n        head_grid = self.head_position_encode(torch.ones([x.shape[0], x.shape[1], self.init_size, self.init_size], dtype=self.dtype,\n                       device=x.device))\n        \n        if self.interpolation_mode == 'bilinear':\n            if crop_boxes is not None:\n                \n                head_grid = torch.cat([F.interpolate(hg.unsqueeze(0)[:, :, box[0]: box[2], box[1]: box[3]],\n                                                    (x.shape[2], x.shape[3]), mode='bilinear', align_corners=True)\n                                    for hg, box in\n                                    zip(head_grid, crop_boxes)], dim=0)\n            else:\n                head_grid = F.interpolate(head_grid, (x.shape[2], x.shape[3]), mode='bilinear', align_corners=True)\n        elif self.interpolation_mode == 'nearest':\n            if crop_boxes is not None:\n            \n                head_grid = torch.cat([F.interpolate(hg.unsqueeze(0)[:, :, box[0]: box[2], box[1]: box[3]],\n                                                    (x.shape[2], x.shape[3]), mode='nearest')\n                                    for hg, box in\n                                    zip(head_grid, crop_boxes)], dim=0)\n            else:\n                head_grid = F.interpolate(head_grid, (x.shape[2], x.shape[3]), mode='nearest')\n        else:\n            raise NotImplementedError\n\n        if self.num_classes is not None:\n            assert y.shape == (x.shape[0],)\n            emb = emb + self.label_emb(y)\n        # import ipdb\n        # ipdb.set_trace()\n        h = x.type(self.dtype)\n        for i, module in enumerate(self.input_blocks):\n            h = module(h, emb, context)\n            if i == 0:\n                h += head_grid\n            hs.append(h)\n        h = self.middle_block(h, emb, context)\n        for module in self.output_blocks:\n            h = th.cat([h, hs.pop()], dim=1)\n            h = module(h, emb, context)\n        h = h.type(x.dtype)\n        if self.predict_codebook_ids:\n            return 
self.id_predictor(h)\n        else:\n            return self.out(h)\n\n\nclass EncoderUNetModel(nn.Module):\n    \"\"\"\n    The half UNet model with attention and timestep embedding.\n    For usage, see UNet.\n    \"\"\"\n\n    def __init__(\n        self,\n        image_size,\n        in_channels,\n        model_channels,\n        out_channels,\n        num_res_blocks,\n        attention_resolutions,\n        dropout=0,\n        channel_mult=(1, 2, 4, 8),\n        conv_resample=True,\n        dims=2,\n        use_checkpoint=False,\n        use_fp16=False,\n        num_heads=1,\n        num_head_channels=-1,\n        num_heads_upsample=-1,\n        use_scale_shift_norm=False,\n        resblock_updown=False,\n        use_new_attention_order=False,\n        pool=\"adaptive\",\n        *args,\n        **kwargs\n    ):\n        super().__init__()\n\n        if num_heads_upsample == -1:\n            num_heads_upsample = num_heads\n\n        self.in_channels = in_channels\n        self.model_channels = model_channels\n        self.out_channels = out_channels\n        self.num_res_blocks = num_res_blocks\n        self.attention_resolutions = attention_resolutions\n        self.dropout = dropout\n        self.channel_mult = channel_mult\n        self.conv_resample = conv_resample\n        self.use_checkpoint = use_checkpoint\n        self.dtype = th.float16 if use_fp16 else th.float32\n        self.num_heads = num_heads\n        self.num_head_channels = num_head_channels\n        self.num_heads_upsample = num_heads_upsample\n\n        time_embed_dim = model_channels * 4\n        self.time_embed = nn.Sequential(\n            linear(model_channels, time_embed_dim),\n            nn.SiLU(),\n            linear(time_embed_dim, time_embed_dim),\n        )\n\n        self.input_blocks = nn.ModuleList(\n            [\n                TimestepEmbedSequential(\n                    conv_nd(dims, in_channels, model_channels, 3, padding=1)\n                )\n            ]\n        )\n        self._feature_size = model_channels\n        input_block_chans = [model_channels]\n        ch = model_channels\n        ds = 1\n        for level, mult in enumerate(channel_mult):\n            for _ in range(num_res_blocks):\n                layers = [\n                    ResBlock(\n                        ch,\n                        time_embed_dim,\n                        dropout,\n                        out_channels=mult * model_channels,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = mult * model_channels\n                if ds in attention_resolutions:\n                    layers.append(\n                        AttentionBlock(\n                            ch,\n                            use_checkpoint=use_checkpoint,\n                            num_heads=num_heads,\n                            num_head_channels=num_head_channels,\n                            use_new_attention_order=use_new_attention_order,\n                        )\n                    )\n                self.input_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n                input_block_chans.append(ch)\n            if level != len(channel_mult) - 1:\n                out_ch = ch\n                self.input_blocks.append(\n                    TimestepEmbedSequential(\n                        ResBlock(\n                       
     ch,\n                            time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            down=True,\n                        )\n                        if resblock_updown\n                        else Downsample(\n                            ch, conv_resample, dims=dims, out_channels=out_ch\n                        )\n                    )\n                )\n                ch = out_ch\n                input_block_chans.append(ch)\n                ds *= 2\n                self._feature_size += ch\n\n        self.middle_block = TimestepEmbedSequential(\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n            AttentionBlock(\n                ch,\n                use_checkpoint=use_checkpoint,\n                num_heads=num_heads,\n                num_head_channels=num_head_channels,\n                use_new_attention_order=use_new_attention_order,\n            ),\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n        )\n        self._feature_size += ch\n        self.pool = pool\n        if pool == \"adaptive\":\n            self.out = nn.Sequential(\n                normalization(ch),\n                nn.SiLU(),\n                nn.AdaptiveAvgPool2d((1, 1)),\n                zero_module(conv_nd(dims, ch, out_channels, 1)),\n                nn.Flatten(),\n            )\n        elif pool == \"attention\":\n            assert num_head_channels != -1\n            self.out = nn.Sequential(\n                normalization(ch),\n                nn.SiLU(),\n                AttentionPool2d(\n                    (image_size // ds), ch, num_head_channels, out_channels\n                ),\n            )\n        elif pool == \"spatial\":\n            self.out = nn.Sequential(\n                nn.Linear(self._feature_size, 2048),\n                nn.ReLU(),\n                nn.Linear(2048, self.out_channels),\n            )\n        elif pool == \"spatial_v2\":\n            self.out = nn.Sequential(\n                nn.Linear(self._feature_size, 2048),\n                normalization(2048),\n                nn.SiLU(),\n                nn.Linear(2048, self.out_channels),\n            )\n        else:\n            raise NotImplementedError(f\"Unexpected {pool} pooling\")\n\n    def convert_to_fp16(self):\n        \"\"\"\n        Convert the torso of the model to float16.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f16)\n        self.middle_block.apply(convert_module_to_f16)\n\n    def convert_to_fp32(self):\n        \"\"\"\n        Convert the torso of the model to float32.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f32)\n        self.middle_block.apply(convert_module_to_f32)\n\n    def forward(self, x, timesteps):\n        \"\"\"\n        Apply the model to an input batch.\n        :param x: an [N x C x ...] 
Tensor of inputs.\n        :param timesteps: a 1-D batch of timesteps.\n        :return: an [N x K] Tensor of outputs.\n        \"\"\"\n        emb = self.time_embed(timestep_embedding(\n            timesteps, self.model_channels))\n\n        results = []\n        h = x.type(self.dtype)\n        for module in self.input_blocks:\n            h = module(h, emb)\n            if self.pool.startswith(\"spatial\"):\n                results.append(h.type(x.dtype).mean(dim=(2, 3)))\n        h = self.middle_block(h, emb)\n        if self.pool.startswith(\"spatial\"):\n            results.append(h.type(x.dtype).mean(dim=(2, 3)))\n            h = th.cat(results, axis=-1)\n            return self.out(h)\n        else:\n            h = h.type(x.dtype)\n            return self.out(h)\n"
  },
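  {
    "path": "examples/patch_grid_sketch.py",
    "content": "# Illustrative sketch, not a file from the original repo: mimics how\n# UNetModelPatch.forward aligns its sinusoidal head grid with per-sample crop\n# boxes.  A global grid at init_size is built once; each sample's box\n# (y0, x0, y1, x1, in grid coordinates) is cut out and resized to the current\n# feature resolution, so every patch keeps a consistent global position\n# signal.  Channel count, sizes and boxes below are toy values for the demo.\nimport torch\nimport torch.nn.functional as F\n\n\ndef crop_and_resize_grid(grid, crop_boxes, out_hw, mode='bilinear'):\n    # grid: [N, C, S, S]; crop_boxes: one (y0, x0, y1, x1) per sample.\n    outs = []\n    for g, box in zip(grid, crop_boxes):\n        patch = g.unsqueeze(0)[:, :, box[0]:box[2], box[1]:box[3]]\n        if mode == 'bilinear':\n            patch = F.interpolate(patch, out_hw, mode='bilinear', align_corners=True)\n        else:\n            patch = F.interpolate(patch, out_hw, mode='nearest')\n        outs.append(patch)\n    return torch.cat(outs, dim=0)\n\n\nif __name__ == '__main__':\n    grid = torch.randn(2, 8, 512, 512)  # stands in for head_position_encode(...)\n    boxes = [(0, 0, 256, 256), (128, 128, 384, 384)]\n    print(crop_and_resize_grid(grid, boxes, (64, 64)).shape)  # [2, 8, 64, 64]\n"
  },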
  {
    "path": "ldm/modules/diffusionmodules/positional_encoding.py",
    "content": "#!/usr/bin/env python \n# encoding: utf-8\n# @Time     : 3/31/22 21:42\n# @Author   : Zhixing Zhang\n# @File     : positional_encoding.py\n# @Contact  : zhixing.zhang@rutgers.edu\n# @Desc     :\n\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\n\nclass SinusoidalPositionalEmbedding(nn.Module):\n    \"\"\"Sinusoidal Positional Embedding 1D or 2D (SPE/SPE2d).\n\n    This module is a modified from:\n    https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py # noqa\n\n    Based on the original SPE in single dimension, we implement a 2D sinusoidal\n    positional encodding (SPE2d), as introduced in Positional Encoding as\n    Spatial Inductive Bias in GANs, CVPR'2021.\n\n    Args:\n        embedding_dim (int): The number of dimensions for the positional\n            encoding.\n        padding_idx (int | list[int]): The index for the padding contents. The\n            padding positions will obtain an encoding vector filling in zeros.\n        init_size (int, optional): The initial size of the positional buffer.\n            Defaults to 1024.\n        div_half_dim (bool, optional): If true, the embedding will be divided\n            by :math:`d/2`. Otherwise, it will be divided by\n            :math:`(d/2 -1)`. Defaults to False.\n        center_shift (int | None, optional): Shift the center point to some\n            index. Defaults to None.\n    \"\"\"\n\n    def __init__(self,\n                 embedding_dim,\n                 padding_idx,\n                 init_size=1024,\n                 div_half_dim=False,\n                 center_shift=None):\n        super().__init__()\n        self.embedding_dim = embedding_dim\n        self.padding_idx = padding_idx\n        self.div_half_dim = div_half_dim\n        self.center_shift = center_shift\n\n        self.weights = SinusoidalPositionalEmbedding.get_embedding(\n            init_size, embedding_dim, padding_idx, self.div_half_dim)\n\n        self.register_buffer('_float_tensor', torch.FloatTensor(1))\n\n        self.max_positions = int(1e5)\n\n    @staticmethod\n    def get_embedding(num_embeddings,\n                      embedding_dim,\n                      padding_idx=None,\n                      div_half_dim=False):\n        \"\"\"Build sinusoidal embeddings.\n\n        This matches the implementation in tensor2tensor, but differs slightly\n        from the description in Section 3.5 of \"Attention Is All You Need\".\n        \"\"\"\n        assert embedding_dim % 2 == 0, (\n            'In this version, we request '\n            f'embedding_dim divisible by 2 but got {embedding_dim}')\n\n        # there is a little difference from the original paper.\n        half_dim = embedding_dim // 2\n        if not div_half_dim:\n            emb = np.log(10000) / (half_dim - 1)\n        else:\n            emb = np.log(1e4) / half_dim\n        # compute exp(-log10000 / d * i)\n        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)\n        emb = torch.arange(\n            num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)],\n                        dim=1).view(num_embeddings, -1)\n        if padding_idx is not None:\n            emb[padding_idx, :] = 0\n\n        return emb\n\n    def forward(self, input, **kwargs):\n        \"\"\"Input is expected to be of size [bsz x seqlen].\n\n        Returned tensor is expected to be of size  [bsz x seq_len x emb_dim]\n        \"\"\"\n        assert 
input.dim() == 2 or input.dim(\n        ) == 4, 'Input dimension should be 2 (1D) or 4 (2D)'\n\n        if input.dim() == 4:\n            return self.make_grid2d_like(input, **kwargs)\n\n        b, seq_len = input.shape\n        max_pos = self.padding_idx + 1 + seq_len\n\n        if self.weights is None or max_pos > self.weights.size(0):\n            # recompute/expand embedding if needed\n            self.weights = SinusoidalPositionalEmbedding.get_embedding(\n                max_pos, self.embedding_dim, self.padding_idx)\n        self.weights = self.weights.to(self._float_tensor)\n\n        positions = self.make_positions(input, self.padding_idx).to(\n            self._float_tensor.device)\n\n        return self.weights.index_select(0, positions.view(-1)).view(\n            b, seq_len, self.embedding_dim).detach()\n\n    def make_positions(self, input, padding_idx):\n        mask = input.ne(padding_idx).int()\n        return (torch.cumsum(mask, dim=1).type_as(mask) *\n                mask).long() + padding_idx\n\n    def make_grid2d(self, height, width, num_batches=1, center_shift=None):\n        h, w = height, width\n        # if `center_shift` is not given from the outside, use\n        # `self.center_shift`\n        if center_shift is None:\n            center_shift = self.center_shift\n\n        h_shift = 0\n        w_shift = 0\n        # center shift to the input grid\n        if center_shift is not None:\n            # if h/w is even, the left center should be aligned with\n            # center shift\n            if h % 2 == 0:\n                h_left_center = h // 2\n                h_shift = center_shift - h_left_center\n            else:\n                h_center = h // 2 + 1\n                h_shift = center_shift - h_center\n\n            if w % 2 == 0:\n                w_left_center = w // 2\n                w_shift = center_shift - w_left_center\n            else:\n                w_center = w // 2 + 1\n                w_shift = center_shift - w_center\n\n        # Note that the index starts from 1 since zero is the padding idx.\n        # axis -- (b, h or w)\n        x_axis = torch.arange(1, w + 1).unsqueeze(0).repeat(num_batches,\n                                                            1) + w_shift\n        y_axis = torch.arange(1, h + 1).unsqueeze(0).repeat(num_batches,\n                                                            1) + h_shift\n\n        # emb -- (b, emb_dim, h or w)\n        x_emb = self(x_axis).transpose(1, 2)\n        y_emb = self(y_axis).transpose(1, 2)\n\n        # make grid for x/y axis\n        # Note that repeat will copy data. 
If using a learned emb, expand may be\n        # better.\n        x_grid = x_emb.unsqueeze(2).repeat(1, 1, h, 1)\n        y_grid = y_emb.unsqueeze(3).repeat(1, 1, 1, w)\n\n        # cat grid -- (b, 2 x emb_dim, h, w)\n        grid = torch.cat([x_grid, y_grid], dim=1)\n        return grid.detach()\n\n    def make_grid2d_like(self, x, center_shift=None):\n        \"\"\"Input tensor with shape of (b, ..., h, w). Return tensor with shape\n        of (b, 2 x emb_dim, h, w)\n\n        Note that the positional embedding highly depends on the function,\n        ``make_positions``.\n        \"\"\"\n        h, w = x.shape[-2:]\n\n        grid = self.make_grid2d(h, w, x.size(0), center_shift)\n\n        return grid.to(x)\n\n\nclass CatersianGrid(nn.Module):\n    \"\"\"Cartesian grid for 2d tensor.\n\n    The Cartesian grid is a commonly used positional encoding in deep learning.\n    In this implementation, we follow the convention of ``grid_sample`` in\n    PyTorch. In other words, ``[-1, -1]`` denotes the left-top corner while\n    ``[1, 1]`` denotes the right-bottom corner.\n    \"\"\"\n\n    def forward(self, x, **kwargs):\n        assert x.dim() == 4\n        return self.make_grid2d_like(x, **kwargs)\n\n    def make_grid2d(self, height, width, num_batches=1, requires_grad=False):\n        h, w = height, width\n        grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))\n        grid_x = 2 * grid_x / max(float(w) - 1., 1.) - 1.\n        grid_y = 2 * grid_y / max(float(h) - 1., 1.) - 1.\n        grid = torch.stack((grid_x, grid_y), 0)\n        grid.requires_grad = requires_grad\n\n        grid = torch.unsqueeze(grid, 0)\n        grid = grid.repeat(num_batches, 1, 1, 1)\n\n        return grid\n\n    def make_grid2d_like(self, x, requires_grad=False):\n        h, w = x.shape[-2:]\n        grid = self.make_grid2d(h, w, x.size(0), requires_grad=requires_grad)\n\n        return grid.to(x)\n"
  },
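  {
    "path": "examples/positional_encoding_sketch.py",
    "content": "# Illustrative sketch, not a file from the original repo: exercises the 2D\n# path of SinusoidalPositionalEmbedding above.  Assumes the repo root is on\n# PYTHONPATH so that `ldm` is importable.  A 4-D input dispatches to\n# make_grid2d_like, and the returned grid concatenates sin/cos encodings for\n# the x and y axes, giving 2 x embedding_dim channels; UNetModelPatch relies\n# on this by choosing embedding_dim = model_channels // 2.\nimport torch\n\nfrom ldm.modules.diffusionmodules.positional_encoding import SinusoidalPositionalEmbedding\n\nspe = SinusoidalPositionalEmbedding(embedding_dim=16, padding_idx=0, init_size=64)\nfeat = torch.zeros(2, 3, 32, 32)  # (b, c, h, w); only the shape matters here\ngrid = spe(feat)                  # 4-D input -> make_grid2d_like(feat)\nprint(grid.shape)                 # torch.Size([2, 32, 32, 32]): 2 x 16 channels\n"
  },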
  {
    "path": "ldm/modules/diffusionmodules/util.py",
    "content": "# adopted from\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\n# and\n# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py\n# and\n# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py\n#\n# thanks!\n\n\nimport os\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom einops import repeat\n\nfrom ldm.util import instantiate_from_config\n\n\ndef make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):\n    if schedule == \"linear\":\n        betas = (\n                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2\n        )\n\n    elif schedule == \"cosine\":\n        timesteps = (\n                torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s\n        )\n        alphas = timesteps / (1 + cosine_s) * np.pi / 2\n        alphas = torch.cos(alphas).pow(2)\n        alphas = alphas / alphas[0]\n        betas = 1 - alphas[1:] / alphas[:-1]\n        betas = np.clip(betas, a_min=0, a_max=0.999)\n\n    elif schedule == \"sqrt_linear\":\n        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)\n    elif schedule == \"sqrt\":\n        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5\n    else:\n        raise ValueError(f\"schedule '{schedule}' unknown.\")\n    return betas.numpy()\n\n\ndef make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):\n    if ddim_discr_method == 'uniform':\n        c = num_ddpm_timesteps // num_ddim_timesteps\n        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))\n    elif ddim_discr_method == 'quad':\n        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)\n    else:\n        raise NotImplementedError(f'There is no ddim discretization method called \"{ddim_discr_method}\"')\n\n    # assert ddim_timesteps.shape[0] == num_ddim_timesteps\n    # add one to get the final alpha values right (the ones from first scale to data during sampling)\n    steps_out = ddim_timesteps + 1\n    if verbose:\n        print(f'Selected timesteps for ddim sampler: {steps_out}')\n    return steps_out\n\n\ndef make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):\n    # select alphas for computing the variance schedule\n    alphas = alphacums[ddim_timesteps]\n    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())\n\n    # according the the formula provided in https://arxiv.org/abs/2010.02502\n    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))\n    if verbose:\n        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')\n        print(f'For the chosen value of eta, which is {eta}, '\n              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')\n    return sigmas, alphas, alphas_prev\n\n\ndef betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):\n    \"\"\"\n    Create a beta schedule that discretizes the given alpha_t_bar function,\n    which defines the cumulative product of (1-beta) over time from t = [0,1].\n    :param num_diffusion_timesteps: the number of betas to 
produce.\n    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and\n                      produces the cumulative product of (1-beta) up to that\n                      part of the diffusion process.\n    :param max_beta: the maximum beta to use; use values lower than 1 to\n                     prevent singularities.\n    \"\"\"\n    betas = []\n    for i in range(num_diffusion_timesteps):\n        t1 = i / num_diffusion_timesteps\n        t2 = (i + 1) / num_diffusion_timesteps\n        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))\n    return np.array(betas)\n\n\ndef extract_into_tensor(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\ndef checkpoint(func, inputs, params, flag):\n    \"\"\"\n    Evaluate a function without caching intermediate activations, allowing for\n    reduced memory at the expense of extra compute in the backward pass.\n    :param func: the function to evaluate.\n    :param inputs: the argument sequence to pass to `func`.\n    :param params: a sequence of parameters `func` depends on but does not\n                   explicitly take as arguments.\n    :param flag: if False, disable gradient checkpointing.\n    \"\"\"\n    if False: # disabled checkpointing to allow requires_grad = False for main model\n        args = tuple(inputs) + tuple(params)\n        return CheckpointFunction.apply(func, len(inputs), *args)\n    else:\n        return func(*inputs)\n\n\nclass CheckpointFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, run_function, length, *args):\n        ctx.run_function = run_function\n        ctx.input_tensors = list(args[:length])\n        ctx.input_params = list(args[length:])\n\n        with torch.no_grad():\n            output_tensors = ctx.run_function(*ctx.input_tensors)\n        return output_tensors\n\n    @staticmethod\n    def backward(ctx, *output_grads):\n        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]\n        with torch.enable_grad():\n            # Fixes a bug where the first op in run_function modifies the\n            # Tensor storage in place, which is not allowed for detach()'d\n            # Tensors.\n            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]\n            output_tensors = ctx.run_function(*shallow_copies)\n        input_grads = torch.autograd.grad(\n            output_tensors,\n            ctx.input_tensors + ctx.input_params,\n            output_grads,\n            allow_unused=True,\n        )\n        del ctx.input_tensors\n        del ctx.input_params\n        del output_tensors\n        return (None, None) + input_grads\n\n\ndef timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):\n    \"\"\"\n    Create sinusoidal timestep embeddings.\n    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param dim: the dimension of the output.\n    :param max_period: controls the minimum frequency of the embeddings.\n    :return: an [N x dim] Tensor of positional embeddings.\n    \"\"\"\n    if not repeat_only:\n        half = dim // 2\n        freqs = torch.exp(\n            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half\n        ).to(device=timesteps.device)\n        args = timesteps[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n    else:\n        embedding = repeat(timesteps, 'b -> b d', d=dim)\n    return embedding\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef scale_module(module, scale):\n    \"\"\"\n    Scale the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().mul_(scale)\n    return module\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef normalization(channels):\n    \"\"\"\n    Make a standard normalization layer.\n    :param channels: number of input channels.\n    :return: an nn.Module for normalization.\n    \"\"\"\n    return GroupNorm32(32, channels)\n\n\n# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.\nclass SiLU(nn.Module):\n    def forward(self, x):\n        return x * torch.sigmoid(x)\n\n\nclass GroupNorm32(nn.GroupNorm):\n    def forward(self, x):\n        return super().forward(x.float()).type(x.dtype)\n\ndef conv_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D convolution module.\n    \"\"\"\n    if dims == 1:\n        return nn.Conv1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.Conv2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.Conv3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\ndef linear(*args, **kwargs):\n    \"\"\"\n    Create a linear module.\n    \"\"\"\n    return nn.Linear(*args, **kwargs)\n\n\ndef avg_pool_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D average pooling module.\n    \"\"\"\n    if dims == 1:\n        return nn.AvgPool1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.AvgPool2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.AvgPool3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\nclass HybridConditioner(nn.Module):\n\n    def __init__(self, c_concat_config, c_crossattn_config):\n        super().__init__()\n        self.concat_conditioner = instantiate_from_config(c_concat_config)\n        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)\n\n    def forward(self, c_concat, c_crossattn):\n        c_concat = self.concat_conditioner(c_concat)\n        c_crossattn = self.crossattn_conditioner(c_crossattn)\n        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}\n\n\ndef noise_like(shape, device, repeat=False):\n    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))\n    noise = lambda: torch.randn(shape, device=device)\n    return repeat_noise() if repeat else noise()"
  },
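  {
    "path": "examples/schedule_sketch.py",
    "content": "# Illustrative sketch, not a file from the original repo: shows the shapes\n# produced by a few helpers from ldm/modules/diffusionmodules/util.py.\n# Assumes the repo root is on PYTHONPATH; the schedule lengths are demo\n# choices, not values taken from any config in this repo.\nimport torch\n\nfrom ldm.modules.diffusionmodules.util import (\n    make_beta_schedule, make_ddim_timesteps, timestep_embedding)\n\n# Sinusoidal embeddings: one row of cos/sin features per timestep index.\nt = torch.randint(0, 1000, (4,))\nprint(timestep_embedding(t, dim=128).shape)  # torch.Size([4, 128])\n\n# A linear schedule of 1000 betas and the cumulative alpha products.\nbetas = make_beta_schedule('linear', n_timestep=1000)\nalphas_cumprod = (1.0 - torch.from_numpy(betas)).cumprod(dim=0)\nprint(betas.shape, float(alphas_cumprod[-1]))  # (1000,) and a small alpha-bar\n\n# 50 DDIM steps drawn uniformly from the 1000 DDPM steps (shifted by one).\nprint(make_ddim_timesteps('uniform', 50, 1000, verbose=False))\n"
  },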
  {
    "path": "ldm/modules/distributions/__init__.py",
    "content": ""
  },
  {
    "path": "ldm/modules/distributions/distributions.py",
    "content": "import torch\nimport numpy as np\n\n\nclass AbstractDistribution:\n    def sample(self):\n        raise NotImplementedError()\n\n    def mode(self):\n        raise NotImplementedError()\n\n\nclass DiracDistribution(AbstractDistribution):\n    def __init__(self, value):\n        self.value = value\n\n    def sample(self):\n        return self.value\n\n    def mode(self):\n        return self.value\n\n\nclass DiagonalGaussianDistribution(object):\n    def __init__(self, parameters, deterministic=False):\n        self.parameters = parameters\n        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)\n        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)\n        self.deterministic = deterministic\n        self.std = torch.exp(0.5 * self.logvar)\n        self.var = torch.exp(self.logvar)\n        if self.deterministic:\n            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)\n\n    def sample(self):\n        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)\n        return x\n\n    def kl(self, other=None):\n        if self.deterministic:\n            return torch.Tensor([0.])\n        else:\n            if other is None:\n                return 0.5 * torch.sum(torch.pow(self.mean, 2)\n                                       + self.var - 1.0 - self.logvar,\n                                       dim=[1, 2, 3])\n            else:\n                return 0.5 * torch.sum(\n                    torch.pow(self.mean - other.mean, 2) / other.var\n                    + self.var / other.var - 1.0 - self.logvar + other.logvar,\n                    dim=[1, 2, 3])\n\n    def nll(self, sample, dims=[1,2,3]):\n        if self.deterministic:\n            return torch.Tensor([0.])\n        logtwopi = np.log(2.0 * np.pi)\n        return 0.5 * torch.sum(\n            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,\n            dim=dims)\n\n    def mode(self):\n        return self.mean\n\n\ndef normal_kl(mean1, logvar1, mean2, logvar2):\n    \"\"\"\n    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12\n    Compute the KL divergence between two gaussians.\n    Shapes are automatically broadcasted, so batches can be compared to\n    scalars, among other use cases.\n    \"\"\"\n    tensor = None\n    for obj in (mean1, logvar1, mean2, logvar2):\n        if isinstance(obj, torch.Tensor):\n            tensor = obj\n            break\n    assert tensor is not None, \"at least one argument must be a Tensor\"\n\n    # Force variances to be Tensors. Broadcasting helps convert scalars to\n    # Tensors, but it does not work for torch.exp().\n    logvar1, logvar2 = [\n        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)\n        for x in (logvar1, logvar2)\n    ]\n\n    return 0.5 * (\n        -1.0\n        + logvar2\n        - logvar1\n        + torch.exp(logvar1 - logvar2)\n        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)\n    )\n"
  },
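  {
    "path": "examples/distributions_sketch.py",
    "content": "# Illustrative sketch, not a file from the original repo: shows how\n# DiagonalGaussianDistribution splits a 2C-channel parameter tensor into mean\n# and log-variance, draws a reparameterized sample, and reports the per-image\n# KL to a standard normal.  Shapes are toy values.\nimport torch\n\nfrom ldm.modules.distributions.distributions import DiagonalGaussianDistribution\n\nparams = torch.randn(2, 8, 16, 16)  # 2C = 8 -> a 4-channel latent\npost = DiagonalGaussianDistribution(params)\nprint(post.sample().shape)  # torch.Size([2, 4, 16, 16])\nprint(post.kl().shape)      # torch.Size([2]): one KL value per batch element\nprint(post.mode().shape)    # the mean, same shape as a sample\n"
  },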
  {
    "path": "ldm/modules/ema.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n    def __init__(self, model, decay=0.9999, use_num_upates=True):\n        super().__init__()\n        if decay < 0.0 or decay > 1.0:\n            raise ValueError('Decay must be between 0 and 1')\n\n        self.m_name2s_name = {}\n        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))\n        self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates\n                             else torch.tensor(-1,dtype=torch.int))\n\n        for name, p in model.named_parameters():\n            if p.requires_grad:\n                #remove as '.'-character is not allowed in buffers\n                s_name = name.replace('.','')\n                self.m_name2s_name.update({name:s_name})\n                self.register_buffer(s_name,p.clone().detach().data)\n\n        self.collected_params = []\n\n    def forward(self,model):\n        decay = self.decay\n\n        if self.num_updates >= 0:\n            self.num_updates += 1\n            decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))\n\n        one_minus_decay = 1.0 - decay\n\n        with torch.no_grad():\n            m_param = dict(model.named_parameters())\n            shadow_params = dict(self.named_buffers())\n\n            for key in m_param:\n                if m_param[key].requires_grad:\n                    sname = self.m_name2s_name[key]\n                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])\n                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))\n                else:\n                    assert not key in self.m_name2s_name\n\n    def copy_to(self, model):\n        m_param = dict(model.named_parameters())\n        shadow_params = dict(self.named_buffers())\n        for key in m_param:\n            if m_param[key].requires_grad:\n                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)\n            else:\n                assert not key in self.m_name2s_name\n\n    def store(self, parameters):\n        \"\"\"\n        Save the current parameters for restoring later.\n        Args:\n          parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            temporarily stored.\n        \"\"\"\n        self.collected_params = [param.clone() for param in parameters]\n\n    def restore(self, parameters):\n        \"\"\"\n        Restore the parameters stored with the `store` method.\n        Useful to validate the model with EMA parameters without affecting the\n        original optimization process. Store the parameters before the\n        `copy_to` method. After validation (or model saving), use this to\n        restore the former parameters.\n        Args:\n          parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            updated with the stored parameters.\n        \"\"\"\n        for c_param, param in zip(self.collected_params, parameters):\n            param.data.copy_(c_param.data)\n"
  },
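  {
    "path": "examples/ema_sketch.py",
    "content": "# Illustrative sketch, not a file from the original repo: the LitEma pattern\n# used during training.  Calling the EMA module on the model after each\n# optimizer step updates the shadow weights; store()/copy_to()/restore() swap\n# the averaged weights in for evaluation and back out again.  The linear\n# layer is a stand-in model.\nimport torch\nfrom torch import nn\n\nfrom ldm.modules.ema import LitEma\n\nmodel = nn.Linear(4, 4)\nema = LitEma(model, decay=0.999)\n\nfor _ in range(3):              # pretend these are training steps\n    with torch.no_grad():\n        model.weight.add_(0.01)\n    ema(model)                  # update the shadow copy\n\nema.store(model.parameters())    # stash the raw training weights\nema.copy_to(model)               # evaluate with EMA weights here ...\nema.restore(model.parameters())  # ... then return to training weights\n"
  },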
  {
    "path": "ldm/modules/embedding_manager.py",
    "content": "import torch\nfrom torch import nn\n\nfrom ldm.data.personalized import per_img_token_list\nfrom transformers import CLIPTokenizer\nfrom functools import partial\n\nDEFAULT_PLACEHOLDER_TOKEN = [\"*\"]\n\nPROGRESSIVE_SCALE = 2000\n\ndef get_clip_token_for_string(tokenizer, string):\n    batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True,\n                               return_overflowing_tokens=False, padding=\"max_length\", return_tensors=\"pt\")\n    tokens = batch_encoding[\"input_ids\"]\n    assert torch.count_nonzero(tokens - 49407) == 2, f\"String '{string}' maps to more than a single token. Please use another string\"\n\n    return tokens[0, 1]\n\ndef get_bert_token_for_string(tokenizer, string):\n    token = tokenizer(string)\n    assert torch.count_nonzero(token) == 3, f\"String '{string}' maps to more than a single token. Please use another string\"\n\n    token = token[0, 1]\n\n    return token\n\ndef get_embedding_for_clip_token(embedder, token):\n    return embedder(token.unsqueeze(0))[0, 0]\n\n\nclass EmbeddingManager(nn.Module):\n    def __init__(\n            self,\n            embedder,\n            placeholder_strings=None,\n            initializer_words=None,\n            per_image_tokens=False,\n            num_vectors_per_token=1,\n            progressive_words=False,\n            **kwargs\n    ):\n        super().__init__()\n\n        self.string_to_token_dict = {}\n        \n        self.string_to_param_dict = nn.ParameterDict()\n\n        self.initial_embeddings = nn.ParameterDict() # These should not be optimized\n\n        self.progressive_words = progressive_words\n        self.progressive_counter = 0\n\n        self.max_vectors_per_token = num_vectors_per_token\n\n        if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder\n            self.is_clip = True\n            get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)\n            get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings)\n            token_dim = 768\n        else: # using LDM's BERT encoder\n            self.is_clip = False\n            get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn)\n            get_embedding_for_tkn = embedder.transformer.token_emb\n            token_dim = 1280\n\n        if per_image_tokens:\n            placeholder_strings.extend(per_img_token_list)\n\n        for idx, placeholder_string in enumerate(placeholder_strings):\n            \n            token = get_token_for_string(placeholder_string)\n\n            if initializer_words and idx < len(initializer_words):\n                init_word_token = get_token_for_string(initializer_words[idx])\n\n                with torch.no_grad():\n                    init_word_embedding = get_embedding_for_tkn(init_word_token.cpu())\n\n                token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True)\n                self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False)\n            else:\n                token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True))\n            \n            self.string_to_token_dict[placeholder_string] = token\n            self.string_to_param_dict[placeholder_string] = token_params\n\n    def forward(\n            
self,\n            tokenized_text,\n            embedded_text,\n    ):\n        b, n, device = *tokenized_text.shape, tokenized_text.device\n\n        for placeholder_string, placeholder_token in self.string_to_token_dict.items():\n\n            placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device)\n\n            if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement\n                placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device))\n                embedded_text[placeholder_idx] = placeholder_embedding\n            else: # otherwise, need to insert and keep track of changing indices\n                if self.progressive_words:\n                    self.progressive_counter += 1\n                    max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE\n                else:\n                    max_step_tokens = self.max_vectors_per_token\n\n                num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens)\n\n                placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device))\n\n                if placeholder_rows.nelement() == 0:\n                    continue\n\n                sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True)\n                sorted_rows = placeholder_rows[sort_idx]\n\n                for idx in range(len(sorted_rows)):\n                    row = sorted_rows[idx]\n                    col = sorted_cols[idx]\n\n                    new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], axis=0)[:n]\n                    new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n]\n\n                    embedded_text[row]  = new_embed_row\n                    tokenized_text[row] = new_token_row\n\n        return embedded_text\n\n    def save(self, ckpt_path):\n        torch.save({\"string_to_token\": self.string_to_token_dict,\n                    \"string_to_param\": self.string_to_param_dict}, ckpt_path)\n\n    def load(self, ckpt_path):\n        ckpt = torch.load(ckpt_path, map_location='cpu')\n\n        self.string_to_token_dict = ckpt[\"string_to_token\"]\n        self.string_to_param_dict = ckpt[\"string_to_param\"]\n\n    def get_embedding_norms_squared(self):\n        all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim\n        param_norm_squared = (all_params * all_params).sum(axis=-1)              # num_placeholders\n\n        return param_norm_squared\n\n    def embedding_parameters(self):\n        return self.string_to_param_dict.parameters()\n\n    def embedding_to_coarse_loss(self):\n        \n        loss = 0.\n        num_embeddings = len(self.initial_embeddings)\n\n        for key in self.initial_embeddings:\n            optimized = self.string_to_param_dict[key]\n            coarse = self.initial_embeddings[key].clone().to(optimized.device)\n\n            loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings\n\n        return loss"
  },
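  {
    "path": "examples/embedding_replacement_sketch.py",
    "content": "# Illustrative, self-contained sketch, not a file from the original repo:\n# the core replacement EmbeddingManager.forward performs when\n# max_vectors_per_token == 1.  Every position whose token id equals the\n# placeholder id gets its embedding row overwritten with the learned vector.\n# The ids and dimensions below are toy values.\nimport torch\n\nplaceholder_token = torch.tensor(265)     # stand-in for the id of '*'\nplaceholder_embedding = torch.randn(768)  # stand-in for the learned parameter\n\ntokenized_text = torch.randint(0, 260, (2, 77))  # ids kept below 265\ntokenized_text[0, 5] = placeholder_token\nembedded_text = torch.randn(2, 77, 768)\n\nidx = torch.where(tokenized_text == placeholder_token)\nembedded_text[idx] = placeholder_embedding  # broadcasts into each matching row\nprint(len(idx[0]), 'placeholder position(s) replaced')\n"
  },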
  {
    "path": "ldm/modules/encoders/__init__.py",
    "content": ""
  },
  {
    "path": "ldm/modules/encoders/modules.py",
    "content": "import torch\nimport torch.nn as nn\nfrom functools import partial\nimport clip\nfrom einops import rearrange, repeat\nfrom transformers import CLIPTokenizer, CLIPTextModel\nimport kornia\n\nfrom ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test\n\ndef _expand_mask(mask, dtype, tgt_len = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\ndef _build_causal_attention_mask(bsz, seq_len, dtype):\n        # lazily create causal attention mask, with full attention between the vision tokens\n        # pytorch uses additive attention mask; fill with -inf\n        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)\n        mask.fill_(torch.tensor(torch.finfo(dtype).min))\n        mask.triu_(1)  # zero out the lower diagonal\n        mask = mask.unsqueeze(1)  # expand mask\n        return mask\n\nclass AbstractEncoder(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def encode(self, *args, **kwargs):\n        raise NotImplementedError\n\n\n\nclass ClassEmbedder(nn.Module):\n    def __init__(self, embed_dim, n_classes=1000, key='class'):\n        super().__init__()\n        self.key = key\n        self.embedding = nn.Embedding(n_classes, embed_dim)\n\n    def forward(self, batch, key=None):\n        if key is None:\n            key = self.key\n        # this is for use in crossattn\n        c = batch[key][:, None]\n        c = self.embedding(c)\n        return c\n\n\nclass TransformerEmbedder(AbstractEncoder):\n    \"\"\"Some transformer encoder layers\"\"\"\n    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device=\"cuda\"):\n        super().__init__()\n        self.device = device\n        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,\n                                              attn_layers=Encoder(dim=n_embed, depth=n_layer))\n\n    def forward(self, tokens):\n        tokens = tokens.to(self.device)  # meh\n        z = self.transformer(tokens, return_embeddings=True)\n        return z\n\n    def encode(self, x):\n        return self(x)\n\n\nclass BERTTokenizer(AbstractEncoder):\n    \"\"\" Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)\"\"\"\n    def __init__(self, device=\"cuda\", vq_interface=True, max_length=77):\n        super().__init__()\n        from transformers import BertTokenizerFast  # TODO: add to requirements\n        self.tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-uncased\")\n        self.device = device\n        self.vq_interface = vq_interface\n        self.max_length = max_length\n\n    def forward(self, text):\n        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,\n                                        return_overflowing_tokens=False, padding=\"max_length\", return_tensors=\"pt\")\n        tokens = batch_encoding[\"input_ids\"].to(self.device)\n        return tokens\n\n    @torch.no_grad()\n    def encode(self, text):\n        tokens = self(text)\n        if not self.vq_interface:\n            return tokens\n        return None, None, [None, None, tokens]\n\n    def decode(self, text):\n        return text\n\n\nclass BERTEmbedder(AbstractEncoder):\n    \"\"\"Uses the BERT tokenizer and adds some transformer encoder layers\"\"\"\n    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,\n                 device=\"cuda\", use_tokenizer=True, embedding_dropout=0.0):\n        super().__init__()\n        self.use_tknz_fn = use_tokenizer\n        if self.use_tknz_fn:\n            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)\n        self.device = device\n        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,\n                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),\n                                              emb_dropout=embedding_dropout)\n\n    def forward(self, text, embedding_manager=None):\n        if self.use_tknz_fn:\n            tokens = self.tknz_fn(text)#.to(self.device)\n        else:\n            tokens = text\n        z = self.transformer(tokens, return_embeddings=True, embedding_manager=embedding_manager)\n        return z\n\n    def encode(self, text, **kwargs):\n        # output of length 77\n        return self(text, **kwargs)\n\nclass SpatialRescaler(nn.Module):\n    def __init__(self,\n                 n_stages=1,\n                 method='bilinear',\n                 multiplier=0.5,\n                 in_channels=3,\n                 out_channels=None,\n                 bias=False):\n        super().__init__()\n        self.n_stages = n_stages\n        assert self.n_stages >= 0\n        assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']\n        self.multiplier = multiplier\n        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)\n        self.remap_output = out_channels is not None\n        if self.remap_output:\n            print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')\n            self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)\n\n    def forward(self, x):\n        for stage in range(self.n_stages):\n            x = self.interpolator(x, scale_factor=self.multiplier)\n\n        if self.remap_output:\n            x = self.channel_mapper(x)\n        return x\n\n    def encode(self, x):\n        return self(x)\n\nclass FrozenCLIPEmbedder(AbstractEncoder):\n    \"\"\"Uses the CLIP transformer encoder for text (from Hugging Face)\"\"\"\n    def __init__(self, version=\"openai/clip-vit-large-patch14\", device=\"cuda\", max_length=77):\n        
super().__init__()\n        self.tokenizer = CLIPTokenizer.from_pretrained(version)\n        self.transformer = CLIPTextModel.from_pretrained(version)\n        self.device = device\n        self.max_length = max_length\n\n        def embedding_forward(\n                self,\n                input_ids = None,\n                position_ids = None,\n                inputs_embeds = None,\n                embedding_manager = None,\n            ) -> torch.Tensor:\n\n                seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]\n\n                if position_ids is None:\n                    position_ids = self.position_ids[:, :seq_length]\n\n                if inputs_embeds is None:\n                    inputs_embeds = self.token_embedding(input_ids)\n\n                if embedding_manager is not None:\n                    inputs_embeds = embedding_manager(input_ids, inputs_embeds)\n\n\n                position_embeddings = self.position_embedding(position_ids)\n                embeddings = inputs_embeds + position_embeddings\n                \n                return embeddings      \n\n        self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings)\n\n        def encoder_forward(\n            self,\n            inputs_embeds,\n            attention_mask = None,\n            causal_attention_mask = None,\n            output_attentions = None,\n            output_hidden_states = None,\n            return_dict = None,\n        ):\n            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n            output_hidden_states = (\n                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n            )\n            return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n            encoder_states = () if output_hidden_states else None\n            all_attentions = () if output_attentions else None\n\n            hidden_states = inputs_embeds\n            for idx, encoder_layer in enumerate(self.layers):\n                if output_hidden_states:\n                    encoder_states = encoder_states + (hidden_states,)\n\n                layer_outputs = encoder_layer(\n                    hidden_states,\n                    attention_mask,\n                    causal_attention_mask,\n                    output_attentions=output_attentions,\n                )\n\n                hidden_states = layer_outputs[0]\n\n                if output_attentions:\n                    all_attentions = all_attentions + (layer_outputs[1],)\n\n            if output_hidden_states:\n                encoder_states = encoder_states + (hidden_states,)\n\n            return hidden_states\n\n        self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)\n\n\n        def text_encoder_forward(\n            self,\n            input_ids = None,\n            attention_mask = None,\n            position_ids = None,\n            output_attentions = None,\n            output_hidden_states = None,\n            return_dict = None,\n            embedding_manager = None,\n        ):\n            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n            output_hidden_states = (\n                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n  
          )\n            return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n            if input_ids is None:\n                raise ValueError(\"You have to specify input_ids\")\n\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n\n            hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager)\n\n            bsz, seq_len = input_shape\n            # CLIP's text model uses causal mask, prepare it here.\n            # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324\n            causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(\n                hidden_states.device\n            )\n\n            # expand attention_mask\n            if attention_mask is not None:\n                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n                attention_mask = _expand_mask(attention_mask, hidden_states.dtype)\n\n            last_hidden_state = self.encoder(\n                inputs_embeds=hidden_states,\n                attention_mask=attention_mask,\n                causal_attention_mask=causal_attention_mask,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n            )\n\n            last_hidden_state = self.final_layer_norm(last_hidden_state)\n\n            return last_hidden_state\n\n        self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)\n\n        def transformer_forward(\n            self,\n            input_ids = None,\n            attention_mask = None,\n            position_ids = None,\n            output_attentions = None,\n            output_hidden_states = None,\n            return_dict = None,\n            embedding_manager = None,\n        ):\n            return self.text_model(\n                input_ids=input_ids,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n                embedding_manager = embedding_manager\n            )\n\n        self.transformer.forward = transformer_forward.__get__(self.transformer)\n\n\n    def freeze(self):\n        self.transformer = self.transformer.eval()\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, text, **kwargs):\n        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,\n                                        return_overflowing_tokens=False, padding=\"max_length\", return_tensors=\"pt\")\n        tokens = batch_encoding[\"input_ids\"].to(self.device)\n        z = self.transformer(input_ids=tokens, **kwargs)\n\n        return z\n\n    def encode(self, text, **kwargs):\n        return self(text, **kwargs)\n\n\nclass FrozenCLIPTextEmbedder(nn.Module):\n    \"\"\"\n    Uses the CLIP transformer encoder for text.\n    \"\"\"\n    def __init__(self, version='ViT-L/14', device=\"cuda\", max_length=77, n_repeat=1, normalize=True):\n        super().__init__()\n        self.model, _ = clip.load(version, jit=False, device=\"cpu\")\n        self.device = device\n        
self.max_length = max_length\n        self.n_repeat = n_repeat\n        self.normalize = normalize\n\n    def freeze(self):\n        self.model = self.model.eval()\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, text):\n        tokens = clip.tokenize(text).to(self.device)\n        z = self.model.encode_text(tokens)\n        if self.normalize:\n            z = z / torch.linalg.norm(z, dim=1, keepdim=True)\n        return z\n\n    def encode(self, text):\n        z = self(text)\n        if z.ndim==2:\n            z = z[:, None, :]\n        z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)\n        return z\n\n\nclass FrozenClipImageEmbedder(nn.Module):\n    \"\"\"\n        Uses the CLIP image encoder.\n        \"\"\"\n    def __init__(\n            self,\n            model,\n            jit=False,\n            device='cuda' if torch.cuda.is_available() else 'cpu',\n            antialias=False,\n        ):\n        super().__init__()\n        self.model, _ = clip.load(name=model, device=device, jit=jit)\n\n        self.antialias = antialias\n\n        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)\n        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)\n\n    def preprocess(self, x):\n        # normalize to [0,1]\n        x = kornia.geometry.resize(x, (224, 224),\n                                   interpolation='bicubic',align_corners=True,\n                                   antialias=self.antialias)\n        x = (x + 1.) / 2.\n        # renormalize according to clip\n        x = kornia.enhance.normalize(x, self.mean, self.std)\n        return x\n\n    def forward(self, x):\n        # x is assumed to be in range [-1,1]\n        return self.model.encode_image(self.preprocess(x))\n\n\nif __name__ == \"__main__\":\n    from ldm.util import count_params\n    model = FrozenCLIPEmbedder()\n    count_params(model, verbose=True)"
  },
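The `FrozenCLIPEmbedder` above never subclasses `CLIPTextModel`; instead, `__init__` rebinds `forward` on each submodule with `__get__` so that an optional `embedding_manager` callback can intercept the token embeddings before position embeddings are added. A minimal sketch of that mechanism, using toy classes and a hypothetical placeholder token id rather than the repo's actual `EmbeddingManager`:

```python
import torch
import torch.nn as nn

class ToyEmbeddings(nn.Module):
    """Stand-in for CLIPTextEmbeddings: just a token-embedding table."""
    def __init__(self, vocab_size=10, dim=4):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, dim)

def patched_forward(self, input_ids, embedding_manager=None):
    # Mirrors the embedding_forward hook above: compute the usual embeddings,
    # then let an optional callback overwrite selected rows.
    inputs_embeds = self.token_embedding(input_ids)
    if embedding_manager is not None:
        inputs_embeds = embedding_manager(input_ids, inputs_embeds)
    return inputs_embeds

emb = ToyEmbeddings()
# __get__ binds the plain function to this one instance, turning it into a
# method with `self` filled in -- the same trick used throughout __init__.
emb.forward = patched_forward.__get__(emb)

PLACEHOLDER_ID = 3                                 # hypothetical learned-token id
learned_vec = torch.zeros(4, requires_grad=True)   # trainable concept vector

def toy_manager(input_ids, inputs_embeds):
    out = inputs_embeds.clone()
    out[input_ids == PLACEHOLDER_ID] = learned_vec  # inject the trainable vector
    return out

ids = torch.tensor([[1, 3, 2]])
print(emb(ids, embedding_manager=toy_manager).shape)  # torch.Size([1, 3, 4])
```

Because `__get__` binds the function to one instance, only the patched objects change behaviour; any other `CLIPTextModel` in the same process keeps the stock `forward`.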
  {
    "path": "ldm/modules/encoders/modules_bak.py",
    "content": "import torch\nimport torch.nn as nn\nfrom functools import partial\nimport clip\nfrom einops import rearrange, repeat\nfrom transformers import CLIPTokenizer, CLIPTextModel\nimport kornia\n\nfrom ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test\n\ndef _expand_mask(mask, dtype, tgt_len = None):\n    \"\"\"\n    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.\n    \"\"\"\n    bsz, src_len = mask.size()\n    tgt_len = tgt_len if tgt_len is not None else src_len\n\n    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)\n\n    inverted_mask = 1.0 - expanded_mask\n\n    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)\n\ndef _build_causal_attention_mask(bsz, seq_len, dtype):\n        # lazily create causal attention mask, with full attention between the vision tokens\n        # pytorch uses additive attention mask; fill with -inf\n        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)\n        mask.fill_(torch.tensor(torch.finfo(dtype).min))\n        mask.triu_(1)  # zero out the lower diagonal\n        mask = mask.unsqueeze(1)  # expand mask\n        return mask\n\nclass AbstractEncoder(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def encode(self, *args, **kwargs):\n        raise NotImplementedError\n\n\n\nclass ClassEmbedder(nn.Module):\n    def __init__(self, embed_dim, n_classes=1000, key='class'):\n        super().__init__()\n        self.key = key\n        self.embedding = nn.Embedding(n_classes, embed_dim)\n\n    def forward(self, batch, key=None):\n        if key is None:\n            key = self.key\n        # this is for use in crossattn\n        c = batch[key][:, None]\n        c = self.embedding(c)\n        return c\n\n\nclass TransformerEmbedder(AbstractEncoder):\n    \"\"\"Some transformer encoder layers\"\"\"\n    def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device=\"cuda\"):\n        super().__init__()\n        self.device = device\n        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,\n                                              attn_layers=Encoder(dim=n_embed, depth=n_layer))\n\n    def forward(self, tokens):\n        tokens = tokens.to(self.device)  # meh\n        z = self.transformer(tokens, return_embeddings=True)\n        return z\n\n    def encode(self, x):\n        return self(x)\n\n\nclass BERTTokenizer(AbstractEncoder):\n    \"\"\" Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)\"\"\"\n    def __init__(self, device=\"cuda\", vq_interface=True, max_length=77):\n        super().__init__()\n        from transformers import BertTokenizerFast  # TODO: add to reuquirements\n        self.tokenizer = BertTokenizerFast.from_pretrained(\"bert-base-uncased\")\n        self.device = device\n        self.vq_interface = vq_interface\n        self.max_length = max_length\n\n    def forward(self, text):\n        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,\n                                        return_overflowing_tokens=False, padding=\"max_length\", return_tensors=\"pt\")\n        tokens = batch_encoding[\"input_ids\"].to(self.device)\n        return tokens\n\n    @torch.no_grad()\n    def encode(self, text):\n        tokens = self(text)\n        if not self.vq_interface:\n            return tokens\n        return None, None, [None, None, tokens]\n\n    def decode(self, text):\n        return text\n\n\nclass BERTEmbedder(AbstractEncoder):\n    \"\"\"Uses the BERT tokenizr model and add some transformer encoder layers\"\"\"\n    def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,\n                 device=\"cuda\",use_tokenizer=True, embedding_dropout=0.0):\n        super().__init__()\n        self.use_tknz_fn = use_tokenizer\n        if self.use_tknz_fn:\n            self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)\n        self.device = device\n        self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,\n                                              attn_layers=Encoder(dim=n_embed, depth=n_layer),\n                                              emb_dropout=embedding_dropout)\n\n    def forward(self, text, embedding_manager=None):\n        if self.use_tknz_fn:\n            tokens = self.tknz_fn(text)#.to(self.device)\n        else:\n            tokens = text\n        z = self.transformer(tokens, return_embeddings=True, embedding_manager=embedding_manager)\n        return z\n\n    def encode(self, text, **kwargs):\n        # output of length 77\n        return self(text, **kwargs)\n\nclass SpatialRescaler(nn.Module):\n    def __init__(self,\n                 n_stages=1,\n                 method='bilinear',\n                 multiplier=0.5,\n                 in_channels=3,\n                 out_channels=None,\n                 bias=False):\n        super().__init__()\n        self.n_stages = n_stages\n        assert self.n_stages >= 0\n        assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']\n        self.multiplier = multiplier\n        self.interpolator = partial(torch.nn.functional.interpolate, mode=method)\n        self.remap_output = out_channels is not None\n        if self.remap_output:\n            print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')\n            self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)\n\n    def forward(self,x):\n        for stage in range(self.n_stages):\n            x = self.interpolator(x, scale_factor=self.multiplier)\n\n\n        if self.remap_output:\n            x = self.channel_mapper(x)\n        return x\n\n    def encode(self, x):\n        return self(x)\n\nclass FrozenCLIPEmbedder(AbstractEncoder):\n    \"\"\"Uses the CLIP transformer encoder for text (from Hugging Face)\"\"\"\n    def __init__(self, version=\"openai/clip-vit-large-patch14\", device=\"cuda\", max_length=77):\n        
super().__init__()\n        self.tokenizer = CLIPTokenizer.from_pretrained(version)\n        self.transformer = CLIPTextModel.from_pretrained(version)\n        self.device = device\n        self.max_length = max_length\n        self.freeze()\n\n        def embedding_forward(\n                self,\n                input_ids = None,\n                position_ids = None,\n                inputs_embeds = None,\n                embedding_manager = None,\n            ) -> torch.Tensor:\n\n                seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]\n\n                if position_ids is None:\n                    position_ids = self.position_ids[:, :seq_length]\n\n                if inputs_embeds is None:\n                    inputs_embeds = self.token_embedding(input_ids)\n\n                if embedding_manager is not None:\n                    inputs_embeds = embedding_manager(input_ids, inputs_embeds)\n\n\n                position_embeddings = self.position_embedding(position_ids)\n                embeddings = inputs_embeds + position_embeddings\n                \n                return embeddings      \n\n        self.transformer.text_model.embeddings.forward = embedding_forward.__get__(self.transformer.text_model.embeddings)\n\n        def encoder_forward(\n            self,\n            inputs_embeds,\n            attention_mask = None,\n            causal_attention_mask = None,\n            output_attentions = None,\n            output_hidden_states = None,\n            return_dict = None,\n        ):\n            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n            output_hidden_states = (\n                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n            )\n            return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n            encoder_states = () if output_hidden_states else None\n            all_attentions = () if output_attentions else None\n\n            hidden_states = inputs_embeds\n            for idx, encoder_layer in enumerate(self.layers):\n                if output_hidden_states:\n                    encoder_states = encoder_states + (hidden_states,)\n\n                layer_outputs = encoder_layer(\n                    hidden_states,\n                    attention_mask,\n                    causal_attention_mask,\n                    output_attentions=output_attentions,\n                )\n\n                hidden_states = layer_outputs[0]\n\n                if output_attentions:\n                    all_attentions = all_attentions + (layer_outputs[1],)\n\n            if output_hidden_states:\n                encoder_states = encoder_states + (hidden_states,)\n\n            return hidden_states\n\n        self.transformer.text_model.encoder.forward = encoder_forward.__get__(self.transformer.text_model.encoder)\n\n\n        def text_encoder_forward(\n            self,\n            input_ids = None,\n            attention_mask = None,\n            position_ids = None,\n            output_attentions = None,\n            output_hidden_states = None,\n            return_dict = None,\n            embedding_manager = None,\n        ):\n            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n            output_hidden_states = (\n                output_hidden_states if output_hidden_states is not None else 
self.config.output_hidden_states\n            )\n            return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n            if input_ids is None:\n                raise ValueError(\"You have to specify input_ids\")\n\n            input_shape = input_ids.size()\n            input_ids = input_ids.view(-1, input_shape[-1])\n\n            hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, embedding_manager=embedding_manager)\n\n            bsz, seq_len = input_shape\n            # CLIP's text model uses causal mask, prepare it here.\n            # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324\n            causal_attention_mask = _build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(\n                hidden_states.device\n            )\n\n            # expand attention_mask\n            if attention_mask is not None:\n                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]\n                attention_mask = _expand_mask(attention_mask, hidden_states.dtype)\n\n            last_hidden_state = self.encoder(\n                inputs_embeds=hidden_states,\n                attention_mask=attention_mask,\n                causal_attention_mask=causal_attention_mask,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n            )\n\n            last_hidden_state = self.final_layer_norm(last_hidden_state)\n\n            return last_hidden_state\n\n        self.transformer.text_model.forward = text_encoder_forward.__get__(self.transformer.text_model)\n\n        def transformer_forward(\n            self,\n            input_ids = None,\n            attention_mask = None,\n            position_ids = None,\n            output_attentions = None,\n            output_hidden_states = None,\n            return_dict = None,\n            embedding_manager = None,\n        ):\n            return self.text_model(\n                input_ids=input_ids,\n                attention_mask=attention_mask,\n                position_ids=position_ids,\n                output_attentions=output_attentions,\n                output_hidden_states=output_hidden_states,\n                return_dict=return_dict,\n                embedding_manager=embedding_manager\n            )\n\n        self.transformer.forward = transformer_forward.__get__(self.transformer)\n\n\n    def freeze(self):\n        self.transformer = self.transformer.eval()\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, text, **kwargs):\n        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,\n                                        return_overflowing_tokens=False, padding=\"max_length\", 
return_tensors=\"pt\")\n        tokens = batch_encoding[\"input_ids\"].to(self.device)        \n        z = self.transformer(input_ids=tokens, **kwargs)\n\n        return z\n\n    def encode(self, text, **kwargs):\n        return self(text, **kwargs)\n\n\nclass FrozenCLIPTextEmbedder(nn.Module):\n    \"\"\"\n    Uses the CLIP transformer encoder for text.\n    \"\"\"\n    def __init__(self, version='ViT-L/14', device=\"cuda\", max_length=77, n_repeat=1, normalize=True):\n        super().__init__()\n        self.model, _ = clip.load(version, jit=False, device=\"cpu\")\n        self.device = device\n        self.max_length = max_length\n        self.n_repeat = n_repeat\n        self.normalize = normalize\n\n    def freeze(self):\n        self.model = self.model.eval()\n        for param in self.parameters():\n            param.requires_grad = False\n\n    def forward(self, text):\n        tokens = clip.tokenize(text).to(self.device)\n        z = self.model.encode_text(tokens)\n        if self.normalize:\n            z = z / torch.linalg.norm(z, dim=1, keepdim=True)\n        return z\n\n    def encode(self, text):\n        z = self(text)\n        if z.ndim==2:\n            z = z[:, None, :]\n        z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)\n        return z\n\n\nclass FrozenClipImageEmbedder(nn.Module):\n    \"\"\"\n        Uses the CLIP image encoder.\n        \"\"\"\n    def __init__(\n            self,\n            model,\n            jit=False,\n            device='cuda' if torch.cuda.is_available() else 'cpu',\n            antialias=False,\n        ):\n        super().__init__()\n        self.model, _ = clip.load(name=model, device=device, jit=jit)\n\n        self.antialias = antialias\n\n        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)\n        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)\n\n    def preprocess(self, x):\n        # normalize to [0,1]\n        x = kornia.geometry.resize(x, (224, 224),\n                                   interpolation='bicubic',align_corners=True,\n                                   antialias=self.antialias)\n        x = (x + 1.) / 2.\n        # renormalize according to clip\n        x = kornia.enhance.normalize(x, self.mean, self.std)\n        return x\n\n    def forward(self, x):\n        # x is assumed to be in range [-1,1]\n        return self.model.encode_image(self.preprocess(x))\n\n\nif __name__ == \"__main__\":\n    from ldm.util import count_params\n    model = FrozenCLIPEmbedder()\n    count_params(model, verbose=True)"
  },
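Both encoder files rely on the additive masks defined at the top of `modules_bak.py` (and its sibling): `_expand_mask` broadcasts a padding mask to `[bsz, 1, tgt_len, src_len]`, and `_build_causal_attention_mask` blocks attention to future tokens by filling those positions with the dtype's minimum value. A small standalone check of the causal construction (the same logic, reproduced here for illustration):

```python
import torch

def build_causal_attention_mask(bsz, seq_len, dtype):
    # Same construction as _build_causal_attention_mask above.
    mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
    mask.fill_(torch.finfo(dtype).min)   # additive mask: block everything ...
    mask.triu_(1)                        # ... then unblock the diagonal and below
    return mask.unsqueeze(1)             # [bsz, 1, seq_len, seq_len]

m = build_causal_attention_mask(1, 4, torch.float32)[0, 0]
print(m)
# Row i is 0.0 for columns <= i (token i may attend to itself and its past)
# and the most negative float for columns > i (future tokens are masked out).
assert (m.tril() == 0).all()                       # visible positions
assert m[0, 1] == torch.finfo(torch.float32).min   # a blocked (future) position
```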
  {
    "path": "ldm/modules/image_degradation/__init__.py",
    "content": "from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr\nfrom ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light\n"
  },
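This `__init__` re-exports the heavy and light degradation pipelines under shorter names. A hedged usage sketch: it assumes a uint8 RGB array as input, and that both variants, as defined in the files that follow, return a dict whose `"image"` entry is the degraded uint8 result:

```python
import numpy as np
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light

hq = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)  # stand-in HQ image

lq = degradation_fn_bsr(hq, sf=4)["image"]              # full BSRGAN pipeline
lq_light = degradation_fn_bsr_light(hq, sf=4)["image"]  # milder blur/noise variant

print(hq.shape, lq.shape, lq_light.shape)  # LQ sides shrink by roughly sf
```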
  {
    "path": "ldm/modules/image_degradation/bsrgan.py",
    "content": "# -*- coding: utf-8 -*-\n\"\"\"\n# --------------------------------------------\n# Super-Resolution\n# --------------------------------------------\n#\n# Kai Zhang (cskaizhang@gmail.com)\n# https://github.com/cszn\n# From 2019/03--2021/08\n# --------------------------------------------\n\"\"\"\n\nimport numpy as np\nimport cv2\nimport torch\n\nfrom functools import partial\nimport random\nfrom scipy import ndimage\nimport scipy\nimport scipy.stats as ss\nfrom scipy.interpolate import interp2d\nfrom scipy.linalg import orth\nimport albumentations\n\nimport ldm.modules.image_degradation.utils_image as util\n\n\ndef modcrop_np(img, sf):\n    '''\n    Args:\n        img: numpy image, WxH or WxHxC\n        sf: scale factor\n    Return:\n        cropped image\n    '''\n    w, h = img.shape[:2]\n    im = np.copy(img)\n    return im[:w - w % sf, :h - h % sf, ...]\n\n\n\"\"\"\n# --------------------------------------------\n# anisotropic Gaussian kernels\n# --------------------------------------------\n\"\"\"\n\n\ndef analytic_kernel(k):\n    \"\"\"Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)\"\"\"\n    k_size = k.shape[0]\n    # Calculate the big kernels size\n    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))\n    # Loop over the small kernel to fill the big one\n    for r in range(k_size):\n        for c in range(k_size):\n            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k\n    # Crop the edges of the big kernel to ignore very small values and increase run time of SR\n    crop = k_size // 2\n    cropped_big_k = big_k[crop:-crop, crop:-crop]\n    # Normalize to 1\n    return cropped_big_k / cropped_big_k.sum()\n\n\ndef anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):\n    \"\"\" generate an anisotropic Gaussian kernel\n    Args:\n        ksize : e.g., 15, kernel size\n        theta : [0,  pi], rotation angle range\n        l1    : [0.1,50], scaling of eigenvalues\n        l2    : [0.1,l1], scaling of eigenvalues\n        If l1 = l2, will get an isotropic Gaussian kernel.\n    Returns:\n        k     : kernel\n    \"\"\"\n\n    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))\n    V = np.array([[v[0], v[1]], [v[1], -v[0]]])\n    D = np.array([[l1, 0], [0, l2]])\n    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))\n    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)\n\n    return k\n\n\ndef gm_blur_kernel(mean, cov, size=15):\n    center = size / 2.0 + 0.5\n    k = np.zeros([size, size])\n    for y in range(size):\n        for x in range(size):\n            cy = y - center + 1\n            cx = x - center + 1\n            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)\n\n    k = k / np.sum(k)\n    return k\n\n\ndef shift_pixel(x, sf, upper_left=True):\n    \"\"\"shift pixel for super-resolution with different scale factors\n    Args:\n        x: WxHxC or WxH\n        sf: scale factor\n        upper_left: shift direction\n    \"\"\"\n    h, w = x.shape[:2]\n    shift = (sf - 1) * 0.5\n    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)\n    if upper_left:\n        x1 = xv + shift\n        y1 = yv + shift\n    else:\n        x1 = xv - shift\n        y1 = yv - shift\n\n    x1 = np.clip(x1, 0, w - 1)\n    y1 = np.clip(y1, 0, h - 1)\n\n    if x.ndim == 2:\n        x = interp2d(xv, yv, x)(x1, y1)\n    if x.ndim == 3:\n        for i in range(x.shape[-1]):\n            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)\n\n    return 
x\n\n\ndef blur(x, k):\n    '''\n    x: image, NxcxHxW\n    k: kernel, Nx1xhxw\n    '''\n    n, c = x.shape[:2]\n    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2\n    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')\n    k = k.repeat(1, c, 1, 1)\n    k = k.view(-1, 1, k.shape[2], k.shape[3])\n    x = x.view(1, -1, x.shape[2], x.shape[3])\n    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)\n    x = x.view(n, c, x.shape[2], x.shape[3])\n\n    return x\n\n\ndef gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):\n    \"\"\"\n    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator\n    # Kai Zhang\n    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var\n    # max_var = 2.5 * sf\n    \"\"\"\n    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix\n    lambda_1 = min_var + np.random.rand() * (max_var - min_var)\n    lambda_2 = min_var + np.random.rand() * (max_var - min_var)\n    theta = np.random.rand() * np.pi  # random theta\n    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2\n\n    # Set COV matrix using Lambdas and Theta\n    LAMBDA = np.diag([lambda_1, lambda_2])\n    Q = np.array([[np.cos(theta), -np.sin(theta)],\n                  [np.sin(theta), np.cos(theta)]])\n    SIGMA = Q @ LAMBDA @ Q.T\n    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]\n\n    # Set expectation position (shifting kernel for aligned image)\n    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)\n    MU = MU[None, None, :, None]\n\n    # Create meshgrid for Gaussian\n    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))\n    Z = np.stack([X, Y], 2)[:, :, :, None]\n\n    # Calculate Gaussian for every pixel of the kernel\n    ZZ = Z - MU\n    ZZ_t = ZZ.transpose(0, 1, 3, 2)\n    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)\n\n    # shift the kernel so it will be centered\n    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)\n\n    # Normalize the kernel and return\n    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)\n    kernel = raw_kernel / np.sum(raw_kernel)\n    return kernel\n\n\ndef fspecial_gaussian(hsize, sigma):\n    hsize = [hsize, hsize]\n    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]\n    std = sigma\n    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))\n    arg = -(x * x + y * y) / (2 * std * std)\n    h = np.exp(arg)\n    h[h < np.finfo(float).eps * h.max()] = 0\n    sumh = h.sum()\n    if sumh != 0:\n        h = h / sumh\n    return h\n\n\ndef fspecial_laplacian(alpha):\n    alpha = max([0, min([alpha, 1])])\n    h1 = alpha / (alpha + 1)\n    h2 = (1 - alpha) / (alpha + 1)\n    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]\n    h = np.array(h)\n    return h\n\n\ndef fspecial(filter_type, *args, **kwargs):\n    '''\n    python code from:\n    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py\n    '''\n    if filter_type == 'gaussian':\n        return fspecial_gaussian(*args, **kwargs)\n    if filter_type == 'laplacian':\n        return fspecial_laplacian(*args, **kwargs)\n\n\n\"\"\"\n# --------------------------------------------\n# degradation models\n# 
--------------------------------------------\n\"\"\"\n\n\ndef bicubic_degradation(x, sf=3):\n    '''\n    Args:\n        x: HxWxC image, [0, 1]\n        sf: down-scale factor\n    Return:\n        bicubicly downsampled LR image\n    '''\n    x = util.imresize_np(x, scale=1 / sf)\n    return x\n\n\ndef srmd_degradation(x, k, sf=3):\n    ''' blur + bicubic downsampling\n    Args:\n        x: HxWxC image, [0, 1]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    Reference:\n        @inproceedings{zhang2018learning,\n          title={Learning a single convolutional super-resolution network for multiple degradations},\n          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},\n          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},\n          pages={3262--3271},\n          year={2018}\n        }\n    '''\n    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'\n    x = bicubic_degradation(x, sf=sf)\n    return x\n\n\ndef dpsr_degradation(x, k, sf=3):\n    ''' bicubic downsampling + blur\n    Args:\n        x: HxWxC image, [0, 1]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    Reference:\n        @inproceedings{zhang2019deep,\n          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},\n          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},\n          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},\n          pages={1671--1681},\n          year={2019}\n        }\n    '''\n    x = bicubic_degradation(x, sf=sf)\n    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')\n    return x\n\n\ndef classical_degradation(x, k, sf=3):\n    ''' blur + downsampling\n    Args:\n        x: HxWxC image, [0, 1]/[0, 255]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    '''\n    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')\n    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))\n    st = 0\n    return x[st::sf, st::sf, ...]\n\n\ndef add_sharpening(img, weight=0.5, radius=50, threshold=10):\n    \"\"\"USM sharpening. borrowed from real-ESRGAN\n    Input image: I; Blurry image: B.\n    1. K = I + weight * (I - B)\n    2. Mask = 1 if abs(I - B) > threshold, else: 0\n    3. Blur mask:\n    4. Out = Mask * K + (1 - Mask) * I\n    Args:\n        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].\n        weight (float): Sharp weight. Default: 1.\n        radius (float): Kernel size of Gaussian blur. 
Default: 50.\n        threshold (int):\n    \"\"\"\n    if radius % 2 == 0:\n        radius += 1\n    blur = cv2.GaussianBlur(img, (radius, radius), 0)\n    residual = img - blur\n    mask = np.abs(residual) * 255 > threshold\n    mask = mask.astype('float32')\n    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)\n\n    K = img + weight * residual\n    K = np.clip(K, 0, 1)\n    return soft_mask * K + (1 - soft_mask) * img\n\n\ndef add_blur(img, sf=4):\n    wd2 = 4.0 + sf\n    wd = 2.0 + 0.2 * sf\n    if random.random() < 0.5:\n        l1 = wd2 * random.random()\n        l2 = wd2 * random.random()\n        k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)\n    else:\n        k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())\n    img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')\n\n    return img\n\n\ndef add_resize(img, sf=4):\n    rnum = np.random.rand()\n    if rnum > 0.8:  # up\n        sf1 = random.uniform(1, 2)\n    elif rnum < 0.7:  # down\n        sf1 = random.uniform(0.5 / sf, 1)\n    else:\n        sf1 = 1.0\n    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))\n    img = np.clip(img, 0.0, 1.0)\n\n    return img\n\n\n# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):\n#     noise_level = random.randint(noise_level1, noise_level2)\n#     rnum = np.random.rand()\n#     if rnum > 0.6:  # add color Gaussian noise\n#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)\n#     elif rnum < 0.4:  # add grayscale Gaussian noise\n#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)\n#     else:  # add  noise\n#         L = noise_level2 / 255.\n#         D = np.diag(np.random.rand(3))\n#         U = orth(np.random.rand(3, 3))\n#         conv = np.dot(np.dot(np.transpose(U), D), U)\n#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)\n#     img = np.clip(img, 0.0, 1.0)\n#     return img\n\ndef add_Gaussian_noise(img, noise_level1=2, noise_level2=25):\n    noise_level = random.randint(noise_level1, noise_level2)\n    rnum = np.random.rand()\n    if rnum > 0.6:  # add color Gaussian noise\n        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)\n    elif rnum < 0.4:  # add grayscale Gaussian noise\n        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)\n    else:  # add  noise\n        L = noise_level2 / 255.\n        D = np.diag(np.random.rand(3))\n        U = orth(np.random.rand(3, 3))\n        conv = np.dot(np.dot(np.transpose(U), D), U)\n        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_speckle_noise(img, noise_level1=2, noise_level2=25):\n    noise_level = random.randint(noise_level1, noise_level2)\n    img = np.clip(img, 0.0, 1.0)\n    rnum = random.random()\n    if rnum > 0.6:\n        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)\n    elif rnum < 0.4:\n        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)\n    else:\n        L = noise_level2 / 255.\n        D = np.diag(np.random.rand(3))\n        U = orth(np.random.rand(3, 3))\n        conv = 
np.dot(np.dot(np.transpose(U), D), U)\n        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_Poisson_noise(img):\n    img = np.clip((img * 255.0).round(), 0, 255) / 255.\n    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]\n    if random.random() < 0.5:\n        img = np.random.poisson(img * vals).astype(np.float32) / vals\n    else:\n        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])\n        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.\n        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray\n        img += noise_gray[:, :, np.newaxis]\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_JPEG_noise(img):\n    quality_factor = random.randint(30, 95)\n    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)\n    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])\n    img = cv2.imdecode(encimg, 1)\n    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)\n    return img\n\n\ndef random_crop(lq, hq, sf=4, lq_patchsize=64):\n    h, w = lq.shape[:2]\n    rnd_h = random.randint(0, h - lq_patchsize)\n    rnd_w = random.randint(0, w - lq_patchsize)\n    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]\n\n    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)\n    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]\n    return lq, hq\n\n\ndef degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):\n    \"\"\"\n    This is the degradation model of BSRGAN from the paper\n    \"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution\"\n    ----------\n    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)\n    sf: scale factor\n    isp_model: camera ISP model\n    Returns\n    -------\n    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]\n    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]\n    \"\"\"\n    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25\n    sf_ori = sf\n\n    h1, w1 = img.shape[:2]\n    img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  
# mod crop\n    h, w = img.shape[:2]\n\n    if h < lq_patchsize * sf or w < lq_patchsize * sf:\n        raise ValueError(f'img size ({h1}X{w1}) is too small!')\n\n    hq = img.copy()\n\n    if sf == 4 and random.random() < scale2_prob:  # downsample1\n        if np.random.rand() < 0.5:\n            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),\n                             interpolation=random.choice([1, 2, 3]))\n        else:\n            img = util.imresize_np(img, 1 / 2, True)\n        img = np.clip(img, 0.0, 1.0)\n        sf = 2\n\n    shuffle_order = random.sample(range(7), 7)\n    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)\n    if idx1 > idx2:  # keep downsample3 last\n        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]\n\n    for i in shuffle_order:\n\n        if i == 0:\n            img = add_blur(img, sf=sf)\n\n        elif i == 1:\n            img = add_blur(img, sf=sf)\n\n        elif i == 2:\n            a, b = img.shape[1], img.shape[0]\n            # downsample2\n            if random.random() < 0.75:\n                sf1 = random.uniform(1, 2 * sf)\n                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),\n                                 interpolation=random.choice([1, 2, 3]))\n            else:\n                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))\n                k_shifted = shift_pixel(k, sf)\n                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel\n                img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')\n                img = img[0::sf, 0::sf, ...]  # nearest downsampling\n            img = np.clip(img, 0.0, 1.0)\n\n        elif i == 3:\n            # downsample3\n            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))\n            img = np.clip(img, 0.0, 1.0)\n\n        elif i == 4:\n            # add Gaussian noise\n            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)\n\n        elif i == 5:\n            # add JPEG noise\n            if random.random() < jpeg_prob:\n                img = add_JPEG_noise(img)\n\n        elif i == 6:\n            # add processed camera sensor noise\n            if random.random() < isp_prob and isp_model is not None:\n                with torch.no_grad():\n                    img, hq = isp_model.forward(img.copy(), hq)\n\n    # add final JPEG compression noise\n    img = add_JPEG_noise(img)\n\n    # random crop\n    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)\n\n    return img, hq\n\n\n# todo no isp_model?\ndef degradation_bsrgan_variant(image, sf=4, isp_model=None):\n    \"\"\"\n    This is the degradation model of BSRGAN from the paper\n    \"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution\"\n    ----------\n    sf: scale factor\n    isp_model: camera ISP model\n    Returns\n    -------\n    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]\n    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]\n    \"\"\"\n    image = util.uint2single(image)\n    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25\n    sf_ori = sf\n\n    h1, w1 = image.shape[:2]\n    image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  
# mod crop\n    h, w = image.shape[:2]\n\n    hq = image.copy()\n\n    if sf == 4 and random.random() < scale2_prob:  # downsample1\n        if np.random.rand() < 0.5:\n            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),\n                               interpolation=random.choice([1, 2, 3]))\n        else:\n            image = util.imresize_np(image, 1 / 2, True)\n        image = np.clip(image, 0.0, 1.0)\n        sf = 2\n\n    shuffle_order = random.sample(range(7), 7)\n    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)\n    if idx1 > idx2:  # keep downsample3 last\n        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]\n\n    for i in shuffle_order:\n\n        if i == 0:\n            image = add_blur(image, sf=sf)\n\n        elif i == 1:\n            image = add_blur(image, sf=sf)\n\n        elif i == 2:\n            a, b = image.shape[1], image.shape[0]\n            # downsample2\n            if random.random() < 0.75:\n                sf1 = random.uniform(1, 2 * sf)\n                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),\n                                   interpolation=random.choice([1, 2, 3]))\n            else:\n                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))\n                k_shifted = shift_pixel(k, sf)\n                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel\n                image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')\n                image = image[0::sf, 0::sf, ...]  # nearest downsampling\n            image = np.clip(image, 0.0, 1.0)\n\n        elif i == 3:\n            # downsample3\n            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))\n            image = np.clip(image, 0.0, 1.0)\n\n        elif i == 4:\n            # add Gaussian noise\n            image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)\n\n        elif i == 5:\n            # add JPEG noise\n            if random.random() < jpeg_prob:\n                image = add_JPEG_noise(image)\n\n        # elif i == 6:\n        #     # add processed camera sensor noise\n        #     if random.random() < isp_prob and isp_model is not None:\n        #         with torch.no_grad():\n        #             img, hq = isp_model.forward(img.copy(), hq)\n\n    # add final JPEG compression noise\n    image = add_JPEG_noise(image)\n    image = util.single2uint(image)\n    example = {\"image\":image}\n    return example\n\n\n# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc...\ndef degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):\n    \"\"\"\n    This is an extended degradation model by combining\n    the degradation models of BSRGAN and Real-ESRGAN\n    ----------\n    img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf)\n    sf: scale factor\n    use_shuffle: the degradation shuffle\n    use_sharp: sharpening the img\n    Returns\n    -------\n    img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]\n    hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]\n    \"\"\"\n\n    h1, w1 = img.shape[:2]\n    img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...]  
# mod crop\n    h, w = img.shape[:2]\n\n    if h < lq_patchsize * sf or w < lq_patchsize * sf:\n        raise ValueError(f'img size ({h1}X{w1}) is too small!')\n\n    if use_sharp:\n        img = add_sharpening(img)\n    hq = img.copy()\n\n    if random.random() < shuffle_prob:\n        shuffle_order = random.sample(range(13), 13)\n    else:\n        shuffle_order = list(range(13))\n        # local shuffle for noise, JPEG is always the last one\n        shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))\n        shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))\n\n    poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1\n\n    for i in shuffle_order:\n        if i == 0:\n            img = add_blur(img, sf=sf)\n        elif i == 1:\n            img = add_resize(img, sf=sf)\n        elif i == 2:\n            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)\n        elif i == 3:\n            if random.random() < poisson_prob:\n                img = add_Poisson_noise(img)\n        elif i == 4:\n            if random.random() < speckle_prob:\n                img = add_speckle_noise(img)\n        elif i == 5:\n            if random.random() < isp_prob and isp_model is not None:\n                with torch.no_grad():\n                    img, hq = isp_model.forward(img.copy(), hq)\n        elif i == 6:\n            img = add_JPEG_noise(img)\n        elif i == 7:\n            img = add_blur(img, sf=sf)\n        elif i == 8:\n            img = add_resize(img, sf=sf)\n        elif i == 9:\n            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)\n        elif i == 10:\n            if random.random() < poisson_prob:\n                img = add_Poisson_noise(img)\n        elif i == 11:\n            if random.random() < speckle_prob:\n                img = add_speckle_noise(img)\n        elif i == 12:\n            if random.random() < isp_prob and isp_model is not None:\n                with torch.no_grad():\n                    img, hq = isp_model.forward(img.copy(), hq)\n        else:\n            print('check the shuffle!')\n\n    # resize to desired size\n    img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),\n                     interpolation=random.choice([1, 2, 3]))\n\n    # add final JPEG compression noise\n    img = add_JPEG_noise(img)\n\n    # random crop\n    img, hq = random_crop(img, hq, sf, lq_patchsize)\n\n    return img, hq\n\n\nif __name__ == '__main__':\n\tprint(\"hey\")\n\timg = util.imread_uint('utils/test.png', 3)\n\tprint(img)\n\timg = util.uint2single(img)\n\tprint(img)\n\timg = img[:448, :448]\n\th = img.shape[0] // 4\n\tprint(\"resizing to\", h)\n\tsf = 4\n\tdeg_fn = partial(degradation_bsrgan_variant, sf=sf)\n\timg_hq = img  # keep the clean reference for the side-by-side grid\n\tfor i in range(20):\n\t\tprint(i)\n\t\t# degradation_bsrgan_variant expects a uint8 image and returns {\"image\": uint8 LQ}\n\t\timg_lq = util.uint2single(deg_fn(util.single2uint(img_hq))[\"image\"])\n\t\tprint(img_lq)\n\t\timg_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)[\"image\"]\n\t\tprint(img_lq.shape)\n\t\tprint(\"bicubic\", img_lq_bicubic.shape)\n\t\tprint(img_hq.shape)\n\t\tlq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),\n\t\t                        interpolation=0)\n\t\tlq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),\n\t\t                        interpolation=0)\n\t\timg_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)\n\t\tutil.imsave(img_concat, str(i) + '.png')\n\n\n"
  },
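A note on `anisotropic_Gaussian` in `bsrgan.py` above: it builds the kernel covariance as a similarity transform, `Sigma = V D V^{-1}`, where `D` holds the requested eigenvalues `l1, l2` and `V` is an orthogonal matrix derived from the rotation angle `theta`. A quick illustrative check that the construction preserves the requested spectrum:

```python
import numpy as np

theta, l1, l2 = np.pi / 6, 6.0, 2.0              # rotation angle, eigenvalue scales

v = np.array([np.cos(theta), np.sin(theta)])     # rotated unit vector, as in the source
V = np.array([[v[0], v[1]], [v[1], -v[0]]])      # orthogonal matrix built from it
D = np.diag([l1, l2])
Sigma = V @ D @ np.linalg.inv(V)                 # similarity transform: same spectrum as D

eigvals = np.sort(np.linalg.eigvalsh(Sigma))
print(eigvals)                                   # -> [2. 6.]
assert np.allclose(eigvals, sorted([l1, l2]))
```

Setting `l1 = l2` makes `Sigma` a multiple of the identity, which is why the docstring notes that equal eigenvalues yield an isotropic kernel.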
  {
    "path": "ldm/modules/image_degradation/bsrgan_light.py",
    "content": "# -*- coding: utf-8 -*-\nimport numpy as np\nimport cv2\nimport torch\n\nfrom functools import partial\nimport random\nfrom scipy import ndimage\nimport scipy\nimport scipy.stats as ss\nfrom scipy.interpolate import interp2d\nfrom scipy.linalg import orth\nimport albumentations\n\nimport ldm.modules.image_degradation.utils_image as util\n\n\"\"\"\n# --------------------------------------------\n# Super-Resolution\n# --------------------------------------------\n#\n# Kai Zhang (cskaizhang@gmail.com)\n# https://github.com/cszn\n# From 2019/03--2021/08\n# --------------------------------------------\n\"\"\"\n\n\ndef modcrop_np(img, sf):\n    '''\n    Args:\n        img: numpy image, WxH or WxHxC\n        sf: scale factor\n    Return:\n        cropped image\n    '''\n    w, h = img.shape[:2]\n    im = np.copy(img)\n    return im[:w - w % sf, :h - h % sf, ...]\n\n\n\"\"\"\n# --------------------------------------------\n# anisotropic Gaussian kernels\n# --------------------------------------------\n\"\"\"\n\n\ndef analytic_kernel(k):\n    \"\"\"Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)\"\"\"\n    k_size = k.shape[0]\n    # Calculate the big kernels size\n    big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))\n    # Loop over the small kernel to fill the big one\n    for r in range(k_size):\n        for c in range(k_size):\n            big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k\n    # Crop the edges of the big kernel to ignore very small values and increase run time of SR\n    crop = k_size // 2\n    cropped_big_k = big_k[crop:-crop, crop:-crop]\n    # Normalize to 1\n    return cropped_big_k / cropped_big_k.sum()\n\n\ndef anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):\n    \"\"\" generate an anisotropic Gaussian kernel\n    Args:\n        ksize : e.g., 15, kernel size\n        theta : [0,  pi], rotation angle range\n        l1    : [0.1,50], scaling of eigenvalues\n        l2    : [0.1,l1], scaling of eigenvalues\n        If l1 = l2, will get an isotropic Gaussian kernel.\n    Returns:\n        k     : kernel\n    \"\"\"\n\n    v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))\n    V = np.array([[v[0], v[1]], [v[1], -v[0]]])\n    D = np.array([[l1, 0], [0, l2]])\n    Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))\n    k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)\n\n    return k\n\n\ndef gm_blur_kernel(mean, cov, size=15):\n    center = size / 2.0 + 0.5\n    k = np.zeros([size, size])\n    for y in range(size):\n        for x in range(size):\n            cy = y - center + 1\n            cx = x - center + 1\n            k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)\n\n    k = k / np.sum(k)\n    return k\n\n\ndef shift_pixel(x, sf, upper_left=True):\n    \"\"\"shift pixel for super-resolution with different scale factors\n    Args:\n        x: WxHxC or WxH\n        sf: scale factor\n        upper_left: shift direction\n    \"\"\"\n    h, w = x.shape[:2]\n    shift = (sf - 1) * 0.5\n    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)\n    if upper_left:\n        x1 = xv + shift\n        y1 = yv + shift\n    else:\n        x1 = xv - shift\n        y1 = yv - shift\n\n    x1 = np.clip(x1, 0, w - 1)\n    y1 = np.clip(y1, 0, h - 1)\n\n    if x.ndim == 2:\n        x = interp2d(xv, yv, x)(x1, y1)\n    if x.ndim == 3:\n        for i in range(x.shape[-1]):\n            x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)\n\n    return 
x\n\n\ndef blur(x, k):\n    '''\n    x: image, NxcxHxW\n    k: kernel, Nx1xhxw\n    '''\n    n, c = x.shape[:2]\n    p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2\n    x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')\n    k = k.repeat(1, c, 1, 1)\n    k = k.view(-1, 1, k.shape[2], k.shape[3])\n    x = x.view(1, -1, x.shape[2], x.shape[3])\n    x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)\n    x = x.view(n, c, x.shape[2], x.shape[3])\n\n    return x\n\n\ndef gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):\n    \"\"\"\n    # modified version of https://github.com/assafshocher/BlindSR_dataset_generator\n    # Kai Zhang\n    # min_var = 0.175 * sf  # variance of the gaussian kernel will be sampled between min_var and max_var\n    # max_var = 2.5 * sf\n    \"\"\"\n    # Set random eigen-vals (lambdas) and angle (theta) for COV matrix\n    lambda_1 = min_var + np.random.rand() * (max_var - min_var)\n    lambda_2 = min_var + np.random.rand() * (max_var - min_var)\n    theta = np.random.rand() * np.pi  # random theta\n    noise = -noise_level + np.random.rand(*k_size) * noise_level * 2\n\n    # Set COV matrix using Lambdas and Theta\n    LAMBDA = np.diag([lambda_1, lambda_2])\n    Q = np.array([[np.cos(theta), -np.sin(theta)],\n                  [np.sin(theta), np.cos(theta)]])\n    SIGMA = Q @ LAMBDA @ Q.T\n    INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]\n\n    # Set expectation position (shifting kernel for aligned image)\n    MU = k_size // 2 - 0.5 * (scale_factor - 1)  # - 0.5 * (scale_factor - k_size % 2)\n    MU = MU[None, None, :, None]\n\n    # Create meshgrid for Gaussian\n    [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))\n    Z = np.stack([X, Y], 2)[:, :, :, None]\n\n    # Calculate Gaussian for every pixel of the kernel\n    ZZ = Z - MU\n    ZZ_t = ZZ.transpose(0, 1, 3, 2)\n    raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)\n\n    # shift the kernel so it will be centered\n    # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)\n\n    # Normalize the kernel and return\n    # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)\n    kernel = raw_kernel / np.sum(raw_kernel)\n    return kernel\n\n\ndef fspecial_gaussian(hsize, sigma):\n    hsize = [hsize, hsize]\n    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]\n    std = sigma\n    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))\n    arg = -(x * x + y * y) / (2 * std * std)\n    h = np.exp(arg)\n    h[h < np.finfo(float).eps * h.max()] = 0\n    sumh = h.sum()\n    if sumh != 0:\n        h = h / sumh\n    return h\n\n\ndef fspecial_laplacian(alpha):\n    alpha = max([0, min([alpha, 1])])\n    h1 = alpha / (alpha + 1)\n    h2 = (1 - alpha) / (alpha + 1)\n    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]\n    h = np.array(h)\n    return h\n\n\ndef fspecial(filter_type, *args, **kwargs):\n    '''\n    python code from:\n    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py\n    '''\n    if filter_type == 'gaussian':\n        return fspecial_gaussian(*args, **kwargs)\n    if filter_type == 'laplacian':\n        return fspecial_laplacian(*args, **kwargs)\n\n\n\"\"\"\n# --------------------------------------------\n# degradation models\n# 
--------------------------------------------\n\"\"\"\n\n\ndef bicubic_degradation(x, sf=3):\n    '''\n    Args:\n        x: HxWxC image, [0, 1]\n        sf: down-scale factor\n    Return:\n        bicubicly downsampled LR image\n    '''\n    x = util.imresize_np(x, scale=1 / sf)\n    return x\n\n\ndef srmd_degradation(x, k, sf=3):\n    ''' blur + bicubic downsampling\n    Args:\n        x: HxWxC image, [0, 1]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    Reference:\n        @inproceedings{zhang2018learning,\n          title={Learning a single convolutional super-resolution network for multiple degradations},\n          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},\n          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},\n          pages={3262--3271},\n          year={2018}\n        }\n    '''\n    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'\n    x = bicubic_degradation(x, sf=sf)\n    return x\n\n\ndef dpsr_degradation(x, k, sf=3):\n    ''' bicubic downsampling + blur\n    Args:\n        x: HxWxC image, [0, 1]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    Reference:\n        @inproceedings{zhang2019deep,\n          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},\n          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},\n          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},\n          pages={1671--1681},\n          year={2019}\n        }\n    '''\n    x = bicubic_degradation(x, sf=sf)\n    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')\n    return x\n\n\ndef classical_degradation(x, k, sf=3):\n    ''' blur + downsampling\n    Args:\n        x: HxWxC image, [0, 1]/[0, 255]\n        k: hxw, double\n        sf: down-scale factor\n    Return:\n        downsampled LR image\n    '''\n    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')\n    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))\n    st = 0\n    return x[st::sf, st::sf, ...]\n\n\ndef add_sharpening(img, weight=0.5, radius=50, threshold=10):\n    \"\"\"USM sharpening. borrowed from real-ESRGAN\n    Input image: I; Blurry image: B.\n    1. K = I + weight * (I - B)\n    2. Mask = 1 if abs(I - B) > threshold, else: 0\n    3. Blur mask:\n    4. Out = Mask * K + (1 - Mask) * I\n    Args:\n        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].\n        weight (float): Sharp weight. Default: 1.\n        radius (float): Kernel size of Gaussian blur. 
\n    if radius % 2 == 0:\n        radius += 1\n    blur = cv2.GaussianBlur(img, (radius, radius), 0)\n    residual = img - blur\n    mask = np.abs(residual) * 255 > threshold\n    mask = mask.astype('float32')\n    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)\n\n    K = img + weight * residual\n    K = np.clip(K, 0, 1)\n    return soft_mask * K + (1 - soft_mask) * img\n\n\ndef add_blur(img, sf=4):\n    wd2 = 4.0 + sf\n    wd = 2.0 + 0.2 * sf\n\n    wd2 = wd2 / 4\n    wd = wd / 4\n\n    if random.random() < 0.5:\n        l1 = wd2 * random.random()\n        l2 = wd2 * random.random()\n        k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)\n    else:\n        k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())\n    img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')\n\n    return img\n\n\ndef add_resize(img, sf=4):\n    rnum = np.random.rand()\n    if rnum > 0.8:  # up\n        sf1 = random.uniform(1, 2)\n    elif rnum < 0.7:  # down\n        sf1 = random.uniform(0.5 / sf, 1)\n    else:\n        sf1 = 1.0\n    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))\n    img = np.clip(img, 0.0, 1.0)\n\n    return img\n\n\ndef add_Gaussian_noise(img, noise_level1=2, noise_level2=25):\n    noise_level = random.randint(noise_level1, noise_level2)\n    rnum = np.random.rand()\n    if rnum > 0.6:  # add color Gaussian noise\n        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)\n    elif rnum < 0.4:  # add grayscale Gaussian noise\n        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)\n    else:  # add channel-correlated Gaussian noise\n        L = noise_level2 / 255.\n        D = np.diag(np.random.rand(3))\n        U = orth(np.random.rand(3, 3))\n        conv = np.dot(np.dot(np.transpose(U), D), U)\n        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_speckle_noise(img, noise_level1=2, noise_level2=25):\n    noise_level = random.randint(noise_level1, noise_level2)\n    img = np.clip(img, 0.0, 1.0)\n    rnum = random.random()\n    if rnum > 0.6:\n        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)\n    elif rnum < 0.4:\n        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)\n    else:\n        L = noise_level2 / 255.\n        D = np.diag(np.random.rand(3))\n        U = orth(np.random.rand(3, 3))\n
        conv = np.dot(np.dot(np.transpose(U), D), U)\n        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_Poisson_noise(img):\n    img = np.clip((img * 255.0).round(), 0, 255) / 255.\n    vals = 10 ** (2 * random.random() + 2.0)  # [2, 4]\n    if random.random() < 0.5:\n        img = np.random.poisson(img * vals).astype(np.float32) / vals\n    else:\n        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])\n        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.\n        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray\n        img += noise_gray[:, :, np.newaxis]\n    img = np.clip(img, 0.0, 1.0)\n    return img\n\n\ndef add_JPEG_noise(img):\n    quality_factor = random.randint(80, 95)\n    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)\n    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])\n    img = cv2.imdecode(encimg, 1)\n    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)\n    return img\n\n\ndef random_crop(lq, hq, sf=4, lq_patchsize=64):\n    h, w = lq.shape[:2]\n    rnd_h = random.randint(0, h - lq_patchsize)\n    rnd_w = random.randint(0, w - lq_patchsize)\n    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]\n\n    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)\n    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]\n    return lq, hq\n\n\ndef degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):\n    \"\"\"\n    This is the degradation model of BSRGAN from the paper\n    \"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution\"\n    ----------\n    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)\n    sf: scale factor\n    isp_model: camera ISP model\n    Returns\n    -------\n    img: low-quality patch, size: lq_patchsize X lq_patchsize X C, range: [0, 1]\n    hq: corresponding high-quality patch, size: (lq_patchsize x sf) X (lq_patchsize x sf) X C, range: [0, 1]\n    \"\"\"\n    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25\n    sf_ori = sf\n\n    h1, w1 = img.shape[:2]\n    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  
# mod crop (slice height by h1 and width by w1)\n    h, w = img.shape[:2]\n\n    if h < lq_patchsize * sf or w < lq_patchsize * sf:\n        raise ValueError(f'img size ({h1}X{w1}) is too small!')\n\n    hq = img.copy()\n\n    if sf == 4 and random.random() < scale2_prob:  # downsample1\n        if np.random.rand() < 0.5:\n            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),\n                             interpolation=random.choice([1, 2, 3]))\n        else:\n            img = util.imresize_np(img, 1 / 2, True)\n        img = np.clip(img, 0.0, 1.0)\n        sf = 2\n\n    shuffle_order = random.sample(range(7), 7)\n    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)\n    if idx1 > idx2:  # keep downsample2 before downsample3\n        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]\n\n    for i in shuffle_order:\n\n        if i == 0:\n            img = add_blur(img, sf=sf)\n\n        elif i == 1:\n            img = add_blur(img, sf=sf)\n\n        elif i == 2:\n            a, b = img.shape[1], img.shape[0]\n            # downsample2\n            if random.random() < 0.75:\n                sf1 = random.uniform(1, 2 * sf)\n                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),\n                                 interpolation=random.choice([1, 2, 3]))\n            else:\n                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))\n                k_shifted = shift_pixel(k, sf)\n                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel\n                img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')\n                img = img[0::sf, 0::sf, ...]  # nearest downsampling\n            img = np.clip(img, 0.0, 1.0)\n\n        elif i == 3:\n            # downsample3\n            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))\n            img = np.clip(img, 0.0, 1.0)\n\n        elif i == 4:\n            # add Gaussian noise\n            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)\n\n        elif i == 5:\n            # add JPEG noise\n            if random.random() < jpeg_prob:\n                img = add_JPEG_noise(img)\n\n        elif i == 6:\n            # add processed camera sensor noise\n            if random.random() < isp_prob and isp_model is not None:\n                with torch.no_grad():\n                    img, hq = isp_model.forward(img.copy(), hq)\n\n    # add final JPEG compression noise\n    img = add_JPEG_noise(img)\n\n    # random crop\n    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)\n\n    return img, hq\n\n\n# todo no isp_model?\ndef degradation_bsrgan_variant(image, sf=4, isp_model=None):\n    \"\"\"\n    This is the degradation model of BSRGAN from the paper\n    \"Designing a Practical Degradation Model for Deep Blind Image Super-Resolution\"\n    ----------\n    image: HxWxC, uint8, [0, 255]\n    sf: scale factor\n    isp_model: camera ISP model\n    Returns\n    -------\n    example: dict with key \"image\" holding the degraded low-quality image, uint8, [0, 255]\n    (unlike degradation_bsrgan, this variant does not crop lq/hq patches)\n    \"\"\"\n    image = util.uint2single(image)\n    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25\n    sf_ori = sf\n\n    h1, w1 = image.shape[:2]\n    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  
# mod crop (slice height by h1 and width by w1)\n    h, w = image.shape[:2]\n\n    hq = image.copy()\n\n    if sf == 4 and random.random() < scale2_prob:  # downsample1\n        if np.random.rand() < 0.5:\n            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),\n                               interpolation=random.choice([1, 2, 3]))\n        else:\n            image = util.imresize_np(image, 1 / 2, True)\n        image = np.clip(image, 0.0, 1.0)\n        sf = 2\n\n    shuffle_order = random.sample(range(7), 7)\n    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)\n    if idx1 > idx2:  # keep downsample2 before downsample3\n        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]\n\n    for i in shuffle_order:\n\n        if i == 0:\n            image = add_blur(image, sf=sf)\n\n        # elif i == 1:  # the second blur of degradation_bsrgan is disabled in this variant\n        #     image = add_blur(image, sf=sf)\n\n        elif i == 2:\n            a, b = image.shape[1], image.shape[0]\n            # downsample2\n            if random.random() < 0.8:\n                sf1 = random.uniform(1, 2 * sf)\n                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),\n                                   interpolation=random.choice([1, 2, 3]))\n            else:\n                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))\n                k_shifted = shift_pixel(k, sf)\n                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel\n                image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')\n                image = image[0::sf, 0::sf, ...]  # nearest downsampling\n\n            image = np.clip(image, 0.0, 1.0)\n\n        elif i == 3:\n            # downsample3\n            image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))\n            image = np.clip(image, 0.0, 1.0)\n\n        elif i == 4:\n            # add Gaussian noise\n            image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)\n\n        elif i == 5:\n            # add JPEG noise\n            if random.random() < jpeg_prob:\n                image = add_JPEG_noise(image)\n\n        # elif i == 6:  # processed camera sensor noise is disabled in this variant\n        #     if random.random() < isp_prob and isp_model is not None:\n        #         with torch.no_grad():\n        #             image, hq = isp_model.forward(image.copy(), hq)\n\n    # add final JPEG compression noise\n    image = add_JPEG_noise(image)\n    image = util.single2uint(image)\n    example = {\"image\": image}\n    return example\n\n\nif __name__ == '__main__':\n    print(\"hey\")\n    img = util.imread_uint('utils/test.png', 3)\n    img = img[:448, :448]\n    h = img.shape[0] // 4\n    print(\"resizing to\", h)\n    sf = 4\n    deg_fn = partial(degradation_bsrgan_variant, sf=sf)\n    for i in range(20):\n        print(i)\n        img_hq = img\n        img_lq = deg_fn(img)[\"image\"]\n        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)\n        print(img_lq)\n        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)[\"image\"]\n        print(img_lq.shape)\n        print(\"bicubic\", img_lq_bicubic.shape)\n        print(img_hq.shape)\n        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),\n                                
interpolation=0)\n        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),\n                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),\n                                        interpolation=0)\n        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)\n        util.imsave(img_concat, str(i) + '.png')\n"
  },
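  {
    "path": "ldm/modules/image_degradation/example_degradation_usage.py",
    "content": "# Illustrative usage sketch added by the editor; not part of the original codebase.\n# It shows how the BSRGAN-style degradation pipeline defined in the module above can\n# be driven end-to-end. The import path below is an assumption: adjust it to the\n# actual module name (bsrgan.py vs. bsrgan_light.py, see the TODO in utils_image.py).\nimport numpy as np\n\nfrom ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant\n\nif __name__ == '__main__':\n    # A synthetic HxWxC uint8 image stands in for a real photo here.\n    hq = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)\n    # The variant takes a uint8 RGB image, runs the randomly shuffled\n    # blur/resize/noise/JPEG chain, and returns a dict holding the uint8 LR image.\n    lq = degradation_bsrgan_variant(hq, sf=4)['image']\n    print('hq:', hq.shape, hq.dtype, '-> lq:', lq.shape, lq.dtype)"
  },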
  {
    "path": "ldm/modules/image_degradation/utils_image.py",
    "content": "import os\nimport math\nimport random\nimport numpy as np\nimport torch\nimport cv2\nfrom torchvision.utils import make_grid\nfrom datetime import datetime\n#import matplotlib.pyplot as plt   # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py\n\n\nos.environ[\"KMP_DUPLICATE_LIB_OK\"]=\"TRUE\"\n\n\n'''\n# --------------------------------------------\n# Kai Zhang (github: https://github.com/cszn)\n# 03/Mar/2019\n# --------------------------------------------\n# https://github.com/twhui/SRGAN-pyTorch\n# https://github.com/xinntao/BasicSR\n# --------------------------------------------\n'''\n\n\nIMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']\n\n\ndef is_image_file(filename):\n    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)\n\n\ndef get_timestamp():\n    return datetime.now().strftime('%y%m%d-%H%M%S')\n\n\ndef imshow(x, title=None, cbar=False, figsize=None):\n    plt.figure(figsize=figsize)\n    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')\n    if title:\n        plt.title(title)\n    if cbar:\n        plt.colorbar()\n    plt.show()\n\n\ndef surf(Z, cmap='rainbow', figsize=None):\n    plt.figure(figsize=figsize)\n    ax3 = plt.axes(projection='3d')\n\n    w, h = Z.shape[:2]\n    xx = np.arange(0,w,1)\n    yy = np.arange(0,h,1)\n    X, Y = np.meshgrid(xx, yy)\n    ax3.plot_surface(X,Y,Z,cmap=cmap)\n    #ax3.contour(X,Y,Z, zdim='z',offset=-2，cmap=cmap)\n    plt.show()\n\n\n'''\n# --------------------------------------------\n# get image pathes\n# --------------------------------------------\n'''\n\n\ndef get_image_paths(dataroot):\n    paths = None  # return None if dataroot is None\n    if dataroot is not None:\n        paths = sorted(_get_paths_from_images(dataroot))\n    return paths\n\n\ndef _get_paths_from_images(path):\n    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)\n    images = []\n    for dirpath, _, fnames in sorted(os.walk(path)):\n        for fname in sorted(fnames):\n            if is_image_file(fname):\n                img_path = os.path.join(dirpath, fname)\n                images.append(img_path)\n    assert images, '{:s} has no valid image file'.format(path)\n    return images\n\n\n'''\n# --------------------------------------------\n# split large images into small images \n# --------------------------------------------\n'''\n\n\ndef patches_from_image(img, p_size=512, p_overlap=64, p_max=800):\n    w, h = img.shape[:2]\n    patches = []\n    if w > p_max and h > p_max:\n        w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int))\n        h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int))\n        w1.append(w-p_size)\n        h1.append(h-p_size)\n#        print(w1)\n#        print(h1)\n        for i in w1:\n            for j in h1:\n                patches.append(img[i:i+p_size, j:j+p_size,:])\n    else:\n        patches.append(img)\n\n    return patches\n\n\ndef imssave(imgs, img_path):\n    \"\"\"\n    imgs: list, N images of size WxHxC\n    \"\"\"\n    img_name, ext = os.path.splitext(os.path.basename(img_path))\n\n    for i, img in enumerate(imgs):\n        if img.ndim == 3:\n            img = img[:, :, [2, 1, 0]]\n        new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')\n        cv2.imwrite(new_path, img)\n\n\ndef split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):\n    \"\"\"\n    split 
    split the large images from original_dataroot into small overlapped images of size (p_size)x(p_size),\n    and save them into taget_dataroot; only images larger than (p_max)x(p_max)\n    will be split.\n    Args:\n        original_dataroot: source directory\n        taget_dataroot: target directory\n        p_size: size of the small images\n        p_overlap: overlap between adjacent patches; the training patch size is a good choice\n        p_max: images smaller than (p_max)x(p_max) are kept unchanged.\n    \"\"\"\n    paths = get_image_paths(original_dataroot)\n    for img_path in paths:\n        # img_name, ext = os.path.splitext(os.path.basename(img_path))\n        img = imread_uint(img_path, n_channels=n_channels)\n        patches = patches_from_image(img, p_size, p_overlap, p_max)\n        imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path)))\n        # if original_dataroot == taget_dataroot:\n        # del img_path\n\n'''\n# --------------------------------------------\n# makedir\n# --------------------------------------------\n'''\n\n\ndef mkdir(path):\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n\ndef mkdirs(paths):\n    if isinstance(paths, str):\n        mkdir(paths)\n    else:\n        for path in paths:\n            mkdir(path)\n\n\ndef mkdir_and_rename(path):\n    if os.path.exists(path):\n        new_name = path + '_archived_' + get_timestamp()\n        print('Path already exists. Rename it to [{:s}]'.format(new_name))\n        os.rename(path, new_name)\n    os.makedirs(path)\n\n\n'''\n# --------------------------------------------\n# read image from path\n# opencv is fast, but reads BGR numpy images\n# --------------------------------------------\n'''\n\n\n# --------------------------------------------\n# get uint8 image of size HxWxn_channels (RGB)\n# --------------------------------------------\ndef imread_uint(path, n_channels=3):\n    #  input: path\n    # output: HxWx3 (RGB or GGG), or HxWx1 (G)\n    if n_channels == 1:\n        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE\n        img = np.expand_dims(img, axis=2)  # HxWx1\n    elif n_channels == 3:\n        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G\n        if img.ndim == 2:\n            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG\n        else:\n            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB\n    return img\n\n\n# --------------------------------------------\n# matlab's imwrite\n# --------------------------------------------\ndef imsave(img, img_path):\n    img = np.squeeze(img)\n    if img.ndim == 3:\n        img = img[:, :, [2, 1, 0]]\n    cv2.imwrite(img_path, img)\n\n\nimwrite = imsave  # identical behaviour; kept as an alias\n\n\n# --------------------------------------------\n# get single image of size HxWxn_channels (BGR)\n# --------------------------------------------\ndef read_img(path):\n    # read image by cv2\n    # return: Numpy float32, HWC, BGR, [0,1]\n    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE\n    img = img.astype(np.float32) / 255.\n    if img.ndim == 2:\n        img = np.expand_dims(img, axis=2)\n    # some images have 4 channels\n    if img.shape[2] > 3:\n        img = img[:, :, :3]\n    return img\n\n\n'''\n# --------------------------------------------\n# image format conversion\n# --------------------------------------------\n# numpy(single) <--->  numpy(unit)\n# numpy(single) <--->  tensor\n# numpy(unit)   <--->  tensor\n# 
--------------------------------------------\n'''\n\n\n# --------------------------------------------\n# numpy(single) [0, 1] <--->  numpy(unit)\n# --------------------------------------------\n\n\ndef uint2single(img):\n\n    return np.float32(img/255.)\n\n\ndef single2uint(img):\n\n    return np.uint8((img.clip(0, 1)*255.).round())\n\n\ndef uint162single(img):\n\n    return np.float32(img/65535.)\n\n\ndef single2uint16(img):\n\n    return np.uint16((img.clip(0, 1)*65535.).round())\n\n\n# --------------------------------------------\n# numpy(unit) (HxWxC or HxW) <--->  tensor\n# --------------------------------------------\n\n\n# convert uint to 4-dimensional torch tensor\ndef uint2tensor4(img):\n    if img.ndim == 2:\n        img = np.expand_dims(img, axis=2)\n    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)\n\n\n# convert uint to 3-dimensional torch tensor\ndef uint2tensor3(img):\n    if img.ndim == 2:\n        img = np.expand_dims(img, axis=2)\n    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)\n\n\n# convert 2/3/4-dimensional torch tensor to uint\ndef tensor2uint(img):\n    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()\n    if img.ndim == 3:\n        img = np.transpose(img, (1, 2, 0))\n    return np.uint8((img*255.0).round())\n\n\n# --------------------------------------------\n# numpy(single) (HxWxC) <--->  tensor\n# --------------------------------------------\n\n\n# convert single (HxWxC) to 3-dimensional torch tensor\ndef single2tensor3(img):\n    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()\n\n\n# convert single (HxWxC) to 4-dimensional torch tensor\ndef single2tensor4(img):\n    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)\n\n\n# convert torch tensor to single\ndef tensor2single(img):\n    img = img.data.squeeze().float().cpu().numpy()\n    if img.ndim == 3:\n        img = np.transpose(img, (1, 2, 0))\n\n    return img\n\n# convert torch tensor to single\ndef tensor2single3(img):\n    img = img.data.squeeze().float().cpu().numpy()\n    if img.ndim == 3:\n        img = np.transpose(img, (1, 2, 0))\n    elif img.ndim == 2:\n        img = np.expand_dims(img, axis=2)\n    return img\n\n\ndef single2tensor5(img):\n    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)\n\n\ndef single32tensor5(img):\n    return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)\n\n\ndef single42tensor4(img):\n    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()\n\n\n# from skimage.io import imread, imsave\ndef tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):\n    '''\n    Converts a torch Tensor into an image Numpy array of BGR channel order\n    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order\n    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)\n    '''\n    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp\n    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]\n    n_dim = tensor.dim()\n    if n_dim == 4:\n        n_img = len(tensor)\n        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()\n        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR\n    elif n_dim == 3:\n        img_np = tensor.numpy()\n        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  
# HWC, BGR\n    elif n_dim == 2:\n        img_np = tensor.numpy()\n    else:\n        raise TypeError(\n            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))\n    if out_type == np.uint8:\n        img_np = (img_np * 255.0).round()\n        # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.\n    return img_np.astype(out_type)\n\n\n'''\n# --------------------------------------------\n# Augmentation, flipe and/or rotate\n# --------------------------------------------\n# The following two are enough.\n# (1) augmet_img: numpy image of WxHxC or WxH\n# (2) augment_img_tensor4: tensor image 1xCxWxH\n# --------------------------------------------\n'''\n\n\ndef augment_img(img, mode=0):\n    '''Kai Zhang (github: https://github.com/cszn)\n    '''\n    if mode == 0:\n        return img\n    elif mode == 1:\n        return np.flipud(np.rot90(img))\n    elif mode == 2:\n        return np.flipud(img)\n    elif mode == 3:\n        return np.rot90(img, k=3)\n    elif mode == 4:\n        return np.flipud(np.rot90(img, k=2))\n    elif mode == 5:\n        return np.rot90(img)\n    elif mode == 6:\n        return np.rot90(img, k=2)\n    elif mode == 7:\n        return np.flipud(np.rot90(img, k=3))\n\n\ndef augment_img_tensor4(img, mode=0):\n    '''Kai Zhang (github: https://github.com/cszn)\n    '''\n    if mode == 0:\n        return img\n    elif mode == 1:\n        return img.rot90(1, [2, 3]).flip([2])\n    elif mode == 2:\n        return img.flip([2])\n    elif mode == 3:\n        return img.rot90(3, [2, 3])\n    elif mode == 4:\n        return img.rot90(2, [2, 3]).flip([2])\n    elif mode == 5:\n        return img.rot90(1, [2, 3])\n    elif mode == 6:\n        return img.rot90(2, [2, 3])\n    elif mode == 7:\n        return img.rot90(3, [2, 3]).flip([2])\n\n\ndef augment_img_tensor(img, mode=0):\n    '''Kai Zhang (github: https://github.com/cszn)\n    '''\n    img_size = img.size()\n    img_np = img.data.cpu().numpy()\n    if len(img_size) == 3:\n        img_np = np.transpose(img_np, (1, 2, 0))\n    elif len(img_size) == 4:\n        img_np = np.transpose(img_np, (2, 3, 1, 0))\n    img_np = augment_img(img_np, mode=mode)\n    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))\n    if len(img_size) == 3:\n        img_tensor = img_tensor.permute(2, 0, 1)\n    elif len(img_size) == 4:\n        img_tensor = img_tensor.permute(3, 2, 0, 1)\n\n    return img_tensor.type_as(img)\n\n\ndef augment_img_np3(img, mode=0):\n    if mode == 0:\n        return img\n    elif mode == 1:\n        return img.transpose(1, 0, 2)\n    elif mode == 2:\n        return img[::-1, :, :]\n    elif mode == 3:\n        img = img[::-1, :, :]\n        img = img.transpose(1, 0, 2)\n        return img\n    elif mode == 4:\n        return img[:, ::-1, :]\n    elif mode == 5:\n        img = img[:, ::-1, :]\n        img = img.transpose(1, 0, 2)\n        return img\n    elif mode == 6:\n        img = img[:, ::-1, :]\n        img = img[::-1, :, :]\n        return img\n    elif mode == 7:\n        img = img[:, ::-1, :]\n        img = img[::-1, :, :]\n        img = img.transpose(1, 0, 2)\n        return img\n\n\ndef augment_imgs(img_list, hflip=True, rot=True):\n    # horizontal flip OR rotate\n    hflip = hflip and random.random() < 0.5\n    vflip = rot and random.random() < 0.5\n    rot90 = rot and random.random() < 0.5\n\n    def _augment(img):\n        if hflip:\n            img = img[:, ::-1, :]\n        if vflip:\n            img = img[::-1, :, :]\n        if rot90:\n      
      img = img.transpose(1, 0, 2)\n        return img\n\n    return [_augment(img) for img in img_list]\n\n\n'''\n# --------------------------------------------\n# modcrop and shave\n# --------------------------------------------\n'''\n\n\ndef modcrop(img_in, scale):\n    # img_in: Numpy, HWC or HW\n    img = np.copy(img_in)\n    if img.ndim == 2:\n        H, W = img.shape\n        H_r, W_r = H % scale, W % scale\n        img = img[:H - H_r, :W - W_r]\n    elif img.ndim == 3:\n        H, W, C = img.shape\n        H_r, W_r = H % scale, W % scale\n        img = img[:H - H_r, :W - W_r, :]\n    else:\n        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))\n    return img\n\n\ndef shave(img_in, border=0):\n    # img_in: Numpy, HWC or HW\n    img = np.copy(img_in)\n    h, w = img.shape[:2]\n    img = img[border:h-border, border:w-border]\n    return img\n\n\n'''\n# --------------------------------------------\n# image processing on numpy images\n# channel_convert(in_c, tar_type, img_list):\n# rgb2ycbcr(img, only_y=True):\n# bgr2ycbcr(img, only_y=True):\n# ycbcr2rgb(img):\n# --------------------------------------------\n'''\n\n\ndef rgb2ycbcr(img, only_y=True):\n    '''same as matlab rgb2ycbcr\n    only_y: only return Y channel\n    Input:\n        uint8, [0, 255]\n        float, [0, 1]\n    '''\n    in_img_type = img.dtype\n    img = img.astype(np.float32)  # work on a float copy; do not modify the input in place\n    if in_img_type != np.uint8:\n        img *= 255.\n    # convert\n    if only_y:\n        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0\n    else:\n        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],\n                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]\n    if in_img_type == np.uint8:\n        rlt = rlt.round()\n    else:\n        rlt /= 255.\n    return rlt.astype(in_img_type)\n\n\ndef ycbcr2rgb(img):\n    '''same as matlab ycbcr2rgb\n    Input:\n        uint8, [0, 255]\n        float, [0, 1]\n    '''\n    in_img_type = img.dtype\n    img = img.astype(np.float32)  # work on a float copy; do not modify the input in place\n    if in_img_type != np.uint8:\n        img *= 255.\n    # convert\n    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],\n                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]\n    if in_img_type == np.uint8:\n        rlt = rlt.round()\n    else:\n        rlt /= 255.\n    return rlt.astype(in_img_type)\n\n\ndef bgr2ycbcr(img, only_y=True):\n    '''bgr version of rgb2ycbcr\n    only_y: only return Y channel\n    Input:\n        uint8, [0, 255]\n        float, [0, 1]\n    '''\n    in_img_type = img.dtype\n    img = img.astype(np.float32)  # work on a float copy; do not modify the input in place\n    if in_img_type != np.uint8:\n        img *= 255.\n    # convert\n    if only_y:\n        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0\n    else:\n        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],\n                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]\n    if in_img_type == np.uint8:\n        rlt = rlt.round()\n    else:\n        rlt /= 255.\n    return rlt.astype(in_img_type)\n\n\ndef channel_convert(in_c, tar_type, img_list):\n    # conversion among BGR, gray and y\n    if in_c == 3 and tar_type == 'gray':  # BGR to gray\n        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]\n        return [np.expand_dims(img, axis=2) for img in gray_list]\n    elif in_c == 3 and tar_type == 'y':  # BGR to y\n        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]\n        return 
[np.expand_dims(img, axis=2) for img in y_list]\n    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR\n        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]\n    else:\n        return img_list\n\n\n'''\n# --------------------------------------------\n# metric, PSNR and SSIM\n# --------------------------------------------\n'''\n\n\n# --------------------------------------------\n# PSNR\n# --------------------------------------------\ndef calculate_psnr(img1, img2, border=0):\n    # img1 and img2 have range [0, 255]\n    #img1 = img1.squeeze()\n    #img2 = img2.squeeze()\n    if not img1.shape == img2.shape:\n        raise ValueError('Input images must have the same dimensions.')\n    h, w = img1.shape[:2]\n    img1 = img1[border:h-border, border:w-border]\n    img2 = img2[border:h-border, border:w-border]\n\n    img1 = img1.astype(np.float64)\n    img2 = img2.astype(np.float64)\n    mse = np.mean((img1 - img2)**2)\n    if mse == 0:\n        return float('inf')\n    return 20 * math.log10(255.0 / math.sqrt(mse))\n\n\n# --------------------------------------------\n# SSIM\n# --------------------------------------------\ndef calculate_ssim(img1, img2, border=0):\n    '''calculate SSIM\n    the same outputs as MATLAB's\n    img1, img2: [0, 255]\n    '''\n    #img1 = img1.squeeze()\n    #img2 = img2.squeeze()\n    if not img1.shape == img2.shape:\n        raise ValueError('Input images must have the same dimensions.')\n    h, w = img1.shape[:2]\n    img1 = img1[border:h-border, border:w-border]\n    img2 = img2[border:h-border, border:w-border]\n\n    if img1.ndim == 2:\n        return ssim(img1, img2)\n    elif img1.ndim == 3:\n        if img1.shape[2] == 3:\n            ssims = []\n            for i in range(3):\n                ssims.append(ssim(img1[:,:,i], img2[:,:,i]))\n            return np.array(ssims).mean()\n        elif img1.shape[2] == 1:\n            return ssim(np.squeeze(img1), np.squeeze(img2))\n    else:\n        raise ValueError('Wrong input image dimensions.')\n\n\ndef ssim(img1, img2):\n    C1 = (0.01 * 255)**2\n    C2 = (0.03 * 255)**2\n\n    img1 = img1.astype(np.float64)\n    img2 = img2.astype(np.float64)\n    kernel = cv2.getGaussianKernel(11, 1.5)\n    window = np.outer(kernel, kernel.transpose())\n\n    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid\n    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]\n    mu1_sq = mu1**2\n    mu2_sq = mu2**2\n    mu1_mu2 = mu1 * mu2\n    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq\n    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq\n    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2\n\n    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *\n                                                            (sigma1_sq + sigma2_sq + C2))\n    return ssim_map.mean()\n\n\n'''\n# --------------------------------------------\n# matlab's bicubic imresize (numpy and torch) [0, 1]\n# --------------------------------------------\n'''\n\n\n# matlab 'imresize' function, now only support 'bicubic'\ndef cubic(x):\n    absx = torch.abs(x)\n    absx2 = absx**2\n    absx3 = absx**3\n    return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \\\n        (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))\n\n\ndef calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):\n    if (scale < 1) and (antialiasing):\n        # Use a modified kernel to 
simultaneously interpolate and antialias- larger kernel width\n        kernel_width = kernel_width / scale\n\n    # Output-space coordinates\n    x = torch.linspace(1, out_length, out_length)\n\n    # Input-space coordinates. Calculate the inverse mapping such that 0.5\n    # in output space maps to 0.5 in input space, and 0.5+scale in output\n    # space maps to 1.5 in input space.\n    u = x / scale + 0.5 * (1 - 1 / scale)\n\n    # What is the left-most pixel that can be involved in the computation?\n    left = torch.floor(u - kernel_width / 2)\n\n    # What is the maximum number of pixels that can be involved in the\n    # computation?  Note: it's OK to use an extra pixel here; if the\n    # corresponding weights are all zero, it will be eliminated at the end\n    # of this function.\n    P = math.ceil(kernel_width) + 2\n\n    # The indices of the input pixels involved in computing the k-th output\n    # pixel are in row k of the indices matrix.\n    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(\n        1, P).expand(out_length, P)\n\n    # The weights used to compute the k-th output pixel are in row k of the\n    # weights matrix.\n    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices\n    # apply cubic kernel\n    if (scale < 1) and (antialiasing):\n        weights = scale * cubic(distance_to_center * scale)\n    else:\n        weights = cubic(distance_to_center)\n    # Normalize the weights matrix so that each row sums to 1.\n    weights_sum = torch.sum(weights, 1).view(out_length, 1)\n    weights = weights / weights_sum.expand(out_length, P)\n\n    # If a column in weights is all zero, get rid of it. only consider the first and last column.\n    weights_zero_tmp = torch.sum((weights == 0), 0)\n    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):\n        indices = indices.narrow(1, 1, P - 2)\n        weights = weights.narrow(1, 1, P - 2)\n    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):\n        indices = indices.narrow(1, 0, P - 2)\n        weights = weights.narrow(1, 0, P - 2)\n    weights = weights.contiguous()\n    indices = indices.contiguous()\n    sym_len_s = -indices.min() + 1\n    sym_len_e = indices.max() - in_length\n    indices = indices + sym_len_s - 1\n    return weights, indices, int(sym_len_s), int(sym_len_e)\n\n\n# --------------------------------------------\n# imresize for tensor image [0, 1]\n# --------------------------------------------\ndef imresize(img, scale, antialiasing=True):\n    # Now the scale should be the same for H and W\n    # input: img: pytorch tensor, CHW or HW [0,1]\n    # output: CHW or HW [0,1] w/o round\n    need_squeeze = True if img.dim() == 2 else False\n    if need_squeeze:\n        img.unsqueeze_(0)\n    in_C, in_H, in_W = img.size()\n    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)\n    kernel_width = 4\n    kernel = 'cubic'\n\n    # Return the desired dimension order for performing the resize.  
The\n    # strategy is to perform the resize first along the dimension with the\n    # smallest scale factor.\n    # Now we do not support this.\n\n    # get weights and indices\n    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(\n        in_H, out_H, scale, kernel, kernel_width, antialiasing)\n    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(\n        in_W, out_W, scale, kernel, kernel_width, antialiasing)\n    # process H dimension\n    # symmetric copying\n    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)\n    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)\n\n    sym_patch = img[:, :sym_len_Hs, :]\n    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(1, inv_idx)\n    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)\n\n    sym_patch = img[:, -sym_len_He:, :]\n    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(1, inv_idx)\n    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)\n\n    out_1 = torch.FloatTensor(in_C, out_H, in_W)\n    kernel_width = weights_H.size(1)\n    for i in range(out_H):\n        idx = int(indices_H[i][0])\n        for j in range(out_C):\n            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])\n\n    # process W dimension\n    # symmetric copying\n    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)\n    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)\n\n    sym_patch = out_1[:, :, :sym_len_Ws]\n    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(2, inv_idx)\n    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)\n\n    sym_patch = out_1[:, :, -sym_len_We:]\n    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(2, inv_idx)\n    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)\n\n    out_2 = torch.FloatTensor(in_C, out_H, out_W)\n    kernel_width = weights_W.size(1)\n    for i in range(out_W):\n        idx = int(indices_W[i][0])\n        for j in range(out_C):\n            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])\n    if need_squeeze:\n        out_2.squeeze_()\n    return out_2\n\n\n# --------------------------------------------\n# imresize for numpy image [0, 1]\n# --------------------------------------------\ndef imresize_np(img, scale, antialiasing=True):\n    # Now the scale should be the same for H and W\n    # input: img: Numpy, HWC or HW [0,1]\n    # output: HWC or HW [0,1] w/o round\n    img = torch.from_numpy(img)\n    need_squeeze = True if img.dim() == 2 else False\n    if need_squeeze:\n        img.unsqueeze_(2)\n\n    in_H, in_W, in_C = img.size()\n    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)\n    kernel_width = 4\n    kernel = 'cubic'\n\n    # Return the desired dimension order for performing the resize.  
The\n    # strategy is to perform the resize first along the dimension with the\n    # smallest scale factor.\n    # Now we do not support this.\n\n    # get weights and indices\n    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(\n        in_H, out_H, scale, kernel, kernel_width, antialiasing)\n    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(\n        in_W, out_W, scale, kernel, kernel_width, antialiasing)\n    # process H dimension\n    # symmetric copying\n    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)\n    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)\n\n    sym_patch = img[:sym_len_Hs, :, :]\n    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(0, inv_idx)\n    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)\n\n    sym_patch = img[-sym_len_He:, :, :]\n    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(0, inv_idx)\n    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)\n\n    out_1 = torch.FloatTensor(out_H, in_W, in_C)\n    kernel_width = weights_H.size(1)\n    for i in range(out_H):\n        idx = int(indices_H[i][0])\n        for j in range(out_C):\n            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])\n\n    # process W dimension\n    # symmetric copying\n    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)\n    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)\n\n    sym_patch = out_1[:, :sym_len_Ws, :]\n    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(1, inv_idx)\n    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)\n\n    sym_patch = out_1[:, -sym_len_We:, :]\n    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(1, inv_idx)\n    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)\n\n    out_2 = torch.FloatTensor(out_H, out_W, in_C)\n    kernel_width = weights_W.size(1)\n    for i in range(out_W):\n        idx = int(indices_W[i][0])\n        for j in range(out_C):\n            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])\n    if need_squeeze:\n        out_2.squeeze_()\n\n    return out_2.numpy()\n\n\nif __name__ == '__main__':\n    print('---')\n#    img = imread_uint('test.bmp', 3)\n#    img = uint2single(img)\n#    img_bicubic = imresize_np(img, 1/4)"
  },
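  {
    "path": "ldm/modules/image_degradation/example_utils_image_usage.py",
    "content": "# Illustrative sketch added by the editor; not part of the original codebase.\n# Exercises the dtype/tensor conversion helpers and the PSNR/SSIM metrics from\n# utils_image.py: a lossless uint8 -> float32 -> tensor -> uint8 round trip\n# should give infinite PSNR and an SSIM of 1.\nimport numpy as np\n\nfrom ldm.modules.image_degradation import utils_image as util\n\nif __name__ == '__main__':\n    img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)\n    # uint8 HxWxC [0, 255] -> float32 [0, 1] -> 1xCxHxW tensor -> uint8 HxWxC.\n    t = util.single2tensor4(util.uint2single(img))\n    rec = util.tensor2uint(t)\n    print('PSNR:', util.calculate_psnr(img, rec))  # inf for an exact round trip\n    print('SSIM:', util.calculate_ssim(img, rec))  # 1.0 for an exact round trip"
  },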
  {
    "path": "ldm/modules/losses/__init__.py",
    "content": "from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator"
  },
  {
    "path": "ldm/modules/losses/contperceptual.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom taming.modules.losses.vqperceptual import *  # TODO: taming dependency yes/no?\n\n\nclass LPIPSWithDiscriminator(nn.Module):\n    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,\n                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,\n                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,\n                 disc_loss=\"hinge\"):\n\n        super().__init__()\n        assert disc_loss in [\"hinge\", \"vanilla\"]\n        self.kl_weight = kl_weight\n        self.pixel_weight = pixelloss_weight\n        self.perceptual_loss = LPIPS().eval()\n        self.perceptual_weight = perceptual_weight\n        # output log variance\n        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)\n\n        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,\n                                                 n_layers=disc_num_layers,\n                                                 use_actnorm=use_actnorm\n                                                 ).apply(weights_init)\n        self.discriminator_iter_start = disc_start\n        self.disc_loss = hinge_d_loss if disc_loss == \"hinge\" else vanilla_d_loss\n        self.disc_factor = disc_factor\n        self.discriminator_weight = disc_weight\n        self.disc_conditional = disc_conditional\n\n    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):\n        if last_layer is not None:\n            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]\n            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]\n        else:\n            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]\n            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]\n\n        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)\n        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()\n        d_weight = d_weight * self.discriminator_weight\n        return d_weight\n\n    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,\n                global_step, last_layer=None, cond=None, split=\"train\",\n                weights=None):\n        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())\n        if self.perceptual_weight > 0:\n            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())\n            rec_loss = rec_loss + self.perceptual_weight * p_loss\n\n        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar\n        weighted_nll_loss = nll_loss\n        if weights is not None:\n            weighted_nll_loss = weights*nll_loss\n        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]\n        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]\n        kl_loss = posteriors.kl()\n        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]\n\n        # now the GAN part\n        if optimizer_idx == 0:\n            # generator update\n            if cond is None:\n                assert not self.disc_conditional\n                logits_fake = self.discriminator(reconstructions.contiguous())\n            else:\n                assert self.disc_conditional\n                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))\n            g_loss = -torch.mean(logits_fake)\n\n            if self.disc_factor > 
0.0:\n                try:\n                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)\n                except RuntimeError:\n                    assert not self.training\n                    d_weight = torch.tensor(0.0)\n            else:\n                d_weight = torch.tensor(0.0)\n\n            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)\n            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss\n\n            log = {\"{}/total_loss\".format(split): loss.clone().detach().mean(), \"{}/logvar\".format(split): self.logvar.detach(),\n                   \"{}/kl_loss\".format(split): kl_loss.detach().mean(), \"{}/nll_loss\".format(split): nll_loss.detach().mean(),\n                   \"{}/rec_loss\".format(split): rec_loss.detach().mean(),\n                   \"{}/d_weight\".format(split): d_weight.detach(),\n                   \"{}/disc_factor\".format(split): torch.tensor(disc_factor),\n                   \"{}/g_loss\".format(split): g_loss.detach().mean(),\n                   }\n            return loss, log\n\n        if optimizer_idx == 1:\n            # second pass for discriminator update\n            if cond is None:\n                logits_real = self.discriminator(inputs.contiguous().detach())\n                logits_fake = self.discriminator(reconstructions.contiguous().detach())\n            else:\n                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))\n                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))\n\n            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)\n            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)\n\n            log = {\"{}/disc_loss\".format(split): d_loss.clone().detach().mean(),\n                   \"{}/logits_real\".format(split): logits_real.detach().mean(),\n                   \"{}/logits_fake\".format(split): logits_fake.detach().mean()\n                   }\n            return d_loss, log\n\n"
  },
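  {
    "path": "ldm/modules/losses/example_adaptive_weight.py",
    "content": "# Illustrative sketch added by the editor; not part of the original codebase.\n# A standalone toy version of LPIPSWithDiscriminator.calculate_adaptive_weight\n# (contperceptual.py): the generator loss is rescaled so that its gradient norm\n# at the decoder's last layer matches that of the reconstruction (NLL) loss,\n# keeping the GAN term from dominating training. The losses below are stand-ins.\nimport torch\n\nlast_layer = torch.nn.Parameter(torch.randn(4, 4))\nx = last_layer.sum()\nnll_loss = (x - 1.0) ** 2  # stand-in for the reconstruction/NLL loss\ng_loss = -x                # stand-in for the generator (GAN) loss\n\nnll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]\ng_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]\n\n# d_weight = ||grad_nll|| / (||grad_g|| + eps), clamped to [0, 1e4] and detached.\nd_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)\nd_weight = torch.clamp(d_weight, 0.0, 1e4).detach()\nprint('adaptive discriminator weight:', d_weight.item())"
  },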
  {
    "path": "ldm/modules/losses/vqperceptual.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom einops import repeat\n\nfrom taming.modules.discriminator.model import NLayerDiscriminator, weights_init\nfrom taming.modules.losses.lpips import LPIPS\nfrom taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss\n\n\ndef hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):\n    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]\n    loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])\n    loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])\n    loss_real = (weights * loss_real).sum() / weights.sum()\n    loss_fake = (weights * loss_fake).sum() / weights.sum()\n    d_loss = 0.5 * (loss_real + loss_fake)\n    return d_loss\n\ndef adopt_weight(weight, global_step, threshold=0, value=0.):\n    if global_step < threshold:\n        weight = value\n    return weight\n\n\ndef measure_perplexity(predicted_indices, n_embed):\n    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py\n    # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally\n    encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)\n    avg_probs = encodings.mean(0)\n    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()\n    cluster_use = torch.sum(avg_probs > 0)\n    return perplexity, cluster_use\n\ndef l1(x, y):\n    return torch.abs(x-y)\n\n\ndef l2(x, y):\n    return torch.pow((x-y), 2)\n\n\nclass VQLPIPSWithDiscriminator(nn.Module):\n    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,\n                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,\n                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,\n                 disc_ndf=64, disc_loss=\"hinge\", n_classes=None, perceptual_loss=\"lpips\",\n                 pixel_loss=\"l1\"):\n        super().__init__()\n        assert disc_loss in [\"hinge\", \"vanilla\"]\n        assert perceptual_loss in [\"lpips\", \"clips\", \"dists\"]\n        assert pixel_loss in [\"l1\", \"l2\"]\n        self.codebook_weight = codebook_weight\n        self.pixel_weight = pixelloss_weight\n        if perceptual_loss == \"lpips\":\n            print(f\"{self.__class__.__name__}: Running with LPIPS.\")\n            self.perceptual_loss = LPIPS().eval()\n        else:\n            raise ValueError(f\"Unknown perceptual loss: >> {perceptual_loss} <<\")\n        self.perceptual_weight = perceptual_weight\n\n        if pixel_loss == \"l1\":\n            self.pixel_loss = l1\n        else:\n            self.pixel_loss = l2\n\n        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,\n                                                 n_layers=disc_num_layers,\n                                                 use_actnorm=use_actnorm,\n                                                 ndf=disc_ndf\n                                                 ).apply(weights_init)\n        self.discriminator_iter_start = disc_start\n        if disc_loss == \"hinge\":\n            self.disc_loss = hinge_d_loss\n        elif disc_loss == \"vanilla\":\n            self.disc_loss = vanilla_d_loss\n        else:\n            raise ValueError(f\"Unknown GAN loss '{disc_loss}'.\")\n        print(f\"VQLPIPSWithDiscriminator running with {disc_loss} loss.\")\n        self.disc_factor = disc_factor\n        self.discriminator_weight = disc_weight\n     
   self.disc_conditional = disc_conditional\n        self.n_classes = n_classes\n\n    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):\n        if last_layer is not None:\n            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]\n            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]\n        else:\n            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]\n            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]\n\n        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)\n        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()\n        d_weight = d_weight * self.discriminator_weight\n        return d_weight\n\n    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,\n                global_step, last_layer=None, cond=None, split=\"train\", predicted_indices=None):\n        if codebook_loss is None:\n            codebook_loss = torch.tensor([0.]).to(inputs.device)\n        #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())\n        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())\n        if self.perceptual_weight > 0:\n            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())\n            rec_loss = rec_loss + self.perceptual_weight * p_loss\n        else:\n            p_loss = torch.tensor([0.0])\n\n        nll_loss = rec_loss\n        #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]\n        nll_loss = torch.mean(nll_loss)\n\n        # now the GAN part\n        if optimizer_idx == 0:\n            # generator update\n            if cond is None:\n                assert not self.disc_conditional\n                logits_fake = self.discriminator(reconstructions.contiguous())\n            else:\n                assert self.disc_conditional\n                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))\n            g_loss = -torch.mean(logits_fake)\n\n            try:\n                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)\n            except RuntimeError:\n                assert not self.training\n                d_weight = torch.tensor(0.0)\n\n            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)\n            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()\n\n            log = {\"{}/total_loss\".format(split): loss.clone().detach().mean(),\n                   \"{}/quant_loss\".format(split): codebook_loss.detach().mean(),\n                   \"{}/nll_loss\".format(split): nll_loss.detach().mean(),\n                   \"{}/rec_loss\".format(split): rec_loss.detach().mean(),\n                   \"{}/p_loss\".format(split): p_loss.detach().mean(),\n                   \"{}/d_weight\".format(split): d_weight.detach(),\n                   \"{}/disc_factor\".format(split): torch.tensor(disc_factor),\n                   \"{}/g_loss\".format(split): g_loss.detach().mean(),\n                   }\n            if predicted_indices is not None:\n                assert self.n_classes is not None\n                with torch.no_grad():\n                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)\n                log[f\"{split}/perplexity\"] = perplexity\n                log[f\"{split}/cluster_usage\"] = cluster_usage\n            return loss, log\n\n        if optimizer_idx == 1:\n            # second pass for discriminator update\n            if cond is None:\n                logits_real = self.discriminator(inputs.contiguous().detach())\n                logits_fake = self.discriminator(reconstructions.contiguous().detach())\n            else:\n                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))\n                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))\n\n            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)\n            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)\n\n            log = {\"{}/disc_loss\".format(split): d_loss.clone().detach().mean(),\n                   \"{}/logits_real\".format(split): logits_real.detach().mean(),\n                   \"{}/logits_fake\".format(split): logits_fake.detach().mean()\n                   }\n            return d_loss, log\n"
  },
  {
    "path": "ldm/modules/x_transformer.py",
    "content": "\"\"\"shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers\"\"\"\nimport torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom functools import partial\nfrom inspect import isfunction\nfrom collections import namedtuple\nfrom einops import rearrange, repeat, reduce\n\n# constants\n\nDEFAULT_DIM_HEAD = 64\n\nIntermediates = namedtuple('Intermediates', [\n    'pre_softmax_attn',\n    'post_softmax_attn'\n])\n\nLayerIntermediates = namedtuple('Intermediates', [\n    'hiddens',\n    'attn_intermediates'\n])\n\n\nclass AbsolutePositionalEmbedding(nn.Module):\n    def __init__(self, dim, max_seq_len):\n        super().__init__()\n        self.emb = nn.Embedding(max_seq_len, dim)\n        self.init_()\n\n    def init_(self):\n        nn.init.normal_(self.emb.weight, std=0.02)\n\n    def forward(self, x):\n        n = torch.arange(x.shape[1], device=x.device)\n        return self.emb(n)[None, :, :]\n\n\nclass FixedPositionalEmbedding(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))\n        self.register_buffer('inv_freq', inv_freq)\n\n    def forward(self, x, seq_dim=1, offset=0):\n        t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset\n        sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)\n        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)\n        return emb[None, :, :]\n\n\n# helpers\n\ndef exists(val):\n    return val is not None\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef always(val):\n    def inner(*args, **kwargs):\n        return val\n    return inner\n\n\ndef not_equals(val):\n    def inner(x):\n        return x != val\n    return inner\n\n\ndef equals(val):\n    def inner(x):\n        return x == val\n    return inner\n\n\ndef max_neg_value(tensor):\n    return -torch.finfo(tensor.dtype).max\n\n\n# keyword argument helpers\n\ndef pick_and_pop(keys, d):\n    values = list(map(lambda key: d.pop(key), keys))\n    return dict(zip(keys, values))\n\n\ndef group_dict_by_key(cond, d):\n    return_val = [dict(), dict()]\n    for key in d.keys():\n        match = bool(cond(key))\n        ind = int(not match)\n        return_val[ind][key] = d[key]\n    return (*return_val,)\n\n\ndef string_begins_with(prefix, str):\n    return str.startswith(prefix)\n\n\ndef group_by_key_prefix(prefix, d):\n    return group_dict_by_key(partial(string_begins_with, prefix), d)\n\n\ndef groupby_prefix_and_trim(prefix, d):\n    kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)\n    kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))\n    return kwargs_without_prefix, kwargs\n\n\n# classes\nclass Scale(nn.Module):\n    def __init__(self, value, fn):\n        super().__init__()\n        self.value = value\n        self.fn = fn\n\n    def forward(self, x, **kwargs):\n        x, *rest = self.fn(x, **kwargs)\n        return (x * self.value, *rest)\n\n\nclass Rezero(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n        self.g = nn.Parameter(torch.zeros(1))\n\n    def forward(self, x, **kwargs):\n        x, *rest = self.fn(x, **kwargs)\n        return (x * self.g, *rest)\n\n\nclass ScaleNorm(nn.Module):\n    def __init__(self, dim, eps=1e-5):\n        super().__init__()\n        
self.scale = dim ** -0.5\n        self.eps = eps\n        self.g = nn.Parameter(torch.ones(1))\n\n    def forward(self, x):\n        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale\n        return x / norm.clamp(min=self.eps) * self.g\n\n\nclass RMSNorm(nn.Module):\n    def __init__(self, dim, eps=1e-8):\n        super().__init__()\n        self.scale = dim ** -0.5\n        self.eps = eps\n        self.g = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x):\n        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale\n        return x / norm.clamp(min=self.eps) * self.g\n\n\nclass Residual(nn.Module):\n    def forward(self, x, residual):\n        return x + residual\n\n\nclass GRUGating(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.gru = nn.GRUCell(dim, dim)\n\n    def forward(self, x, residual):\n        gated_output = self.gru(\n            rearrange(x, 'b n d -> (b n) d'),\n            rearrange(residual, 'b n d -> (b n) d')\n        )\n\n        return gated_output.reshape_as(x)\n\n\n# feedforward\n\nclass GEGLU(nn.Module):\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out * 2)\n\n    def forward(self, x):\n        x, gate = self.proj(x).chunk(2, dim=-1)\n        return x * F.gelu(gate)\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):\n        super().__init__()\n        inner_dim = int(dim * mult)\n        dim_out = default(dim_out, dim)\n        project_in = nn.Sequential(\n            nn.Linear(dim, inner_dim),\n            nn.GELU()\n        ) if not glu else GEGLU(dim, inner_dim)\n\n        self.net = nn.Sequential(\n            project_in,\n            nn.Dropout(dropout),\n            nn.Linear(inner_dim, dim_out)\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\n\n# attention.\nclass Attention(nn.Module):\n    def __init__(\n            self,\n            dim,\n            dim_head=DEFAULT_DIM_HEAD,\n            heads=8,\n            causal=False,\n            mask=None,\n            talking_heads=False,\n            sparse_topk=None,\n            use_entmax15=False,\n            num_mem_kv=0,\n            dropout=0.,\n            on_attn=False\n    ):\n        super().__init__()\n        if use_entmax15:\n            raise NotImplementedError(\"Check out entmax activation instead of softmax activation!\")\n        self.scale = dim_head ** -0.5\n        self.heads = heads\n        self.causal = causal\n        self.mask = mask\n\n        inner_dim = dim_head * heads\n\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\n        self.to_k = nn.Linear(dim, inner_dim, bias=False)\n        self.to_v = nn.Linear(dim, inner_dim, bias=False)\n        self.dropout = nn.Dropout(dropout)\n\n        # talking heads\n        self.talking_heads = talking_heads\n        if talking_heads:\n            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))\n            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))\n\n        # explicit topk sparse attention\n        self.sparse_topk = sparse_topk\n\n        # entmax\n        #self.attn_fn = entmax15 if use_entmax15 else F.softmax\n        self.attn_fn = F.softmax\n\n        # add memory key / values\n        self.num_mem_kv = num_mem_kv\n        if num_mem_kv > 0:\n            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))\n            self.mem_v = nn.Parameter(torch.randn(heads, 
num_mem_kv, dim_head))\n\n        # attention on attention\n        self.attn_on_attn = on_attn\n        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)\n\n    def forward(\n            self,\n            x,\n            context=None,\n            mask=None,\n            context_mask=None,\n            rel_pos=None,\n            sinusoidal_emb=None,\n            prev_attn=None,\n            mem=None\n    ):\n        b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device\n        kv_input = default(context, x)\n\n        q_input = x\n        k_input = kv_input\n        v_input = kv_input\n\n        if exists(mem):\n            k_input = torch.cat((mem, k_input), dim=-2)\n            v_input = torch.cat((mem, v_input), dim=-2)\n\n        if exists(sinusoidal_emb):\n            # in shortformer, the query would start at a position offset depending on the past cached memory\n            offset = k_input.shape[-2] - q_input.shape[-2]\n            q_input = q_input + sinusoidal_emb(q_input, offset=offset)\n            k_input = k_input + sinusoidal_emb(k_input)\n\n        q = self.to_q(q_input)\n        k = self.to_k(k_input)\n        v = self.to_v(v_input)\n\n        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))\n\n        input_mask = None\n        if any(map(exists, (mask, context_mask))):\n            q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())\n            k_mask = q_mask if not exists(context) else context_mask\n            k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())\n            q_mask = rearrange(q_mask, 'b i -> b () i ()')\n            k_mask = rearrange(k_mask, 'b j -> b () () j')\n            input_mask = q_mask * k_mask\n\n        if self.num_mem_kv > 0:\n            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))\n            k = torch.cat((mem_k, k), dim=-2)\n            v = torch.cat((mem_v, v), dim=-2)\n            if exists(input_mask):\n                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)\n\n        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale\n        mask_value = max_neg_value(dots)\n\n        if exists(prev_attn):\n            dots = dots + prev_attn\n\n        pre_softmax_attn = dots\n\n        if talking_heads:\n            dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()\n\n        if exists(rel_pos):\n            dots = rel_pos(dots)\n\n        if exists(input_mask):\n            dots.masked_fill_(~input_mask, mask_value)\n            del input_mask\n\n        if self.causal:\n            i, j = dots.shape[-2:]\n            r = torch.arange(i, device=device)\n            mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')\n            mask = F.pad(mask, (j - i, 0), value=False)\n            dots.masked_fill_(mask, mask_value)\n            del mask\n\n        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:\n            top, _ = dots.topk(self.sparse_topk, dim=-1)\n            vk = top[..., -1].unsqueeze(-1).expand_as(dots)\n            mask = dots < vk\n            dots.masked_fill_(mask, mask_value)\n            del mask\n\n        attn = self.attn_fn(dots, dim=-1)\n        post_softmax_attn = attn\n\n        attn = self.dropout(attn)\n\n        if talking_heads:\n            attn = einsum('b h i j, h k -> b k i j', 
attn, self.post_softmax_proj).contiguous()\n\n        out = einsum('b h i j, b h j d -> b h i d', attn, v)\n        out = rearrange(out, 'b h n d -> b n (h d)')\n\n        intermediates = Intermediates(\n            pre_softmax_attn=pre_softmax_attn,\n            post_softmax_attn=post_softmax_attn\n        )\n\n        return self.to_out(out), intermediates\n\n\nclass AttentionLayers(nn.Module):\n    def __init__(\n            self,\n            dim,\n            depth,\n            heads=8,\n            causal=False,\n            cross_attend=False,\n            only_cross=False,\n            use_scalenorm=False,\n            use_rmsnorm=False,\n            use_rezero=False,\n            rel_pos_num_buckets=32,\n            rel_pos_max_distance=128,\n            position_infused_attn=False,\n            custom_layers=None,\n            sandwich_coef=None,\n            par_ratio=None,\n            residual_attn=False,\n            cross_residual_attn=False,\n            macaron=False,\n            pre_norm=True,\n            gate_residual=False,\n            **kwargs\n    ):\n        super().__init__()\n        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)\n        attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)\n\n        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)\n\n        self.dim = dim\n        self.depth = depth\n        self.layers = nn.ModuleList([])\n\n        self.has_pos_emb = position_infused_attn\n        self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None\n        self.rotary_pos_emb = always(None)\n\n        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'\n        self.rel_pos = None\n\n        self.pre_norm = pre_norm\n\n        self.residual_attn = residual_attn\n        self.cross_residual_attn = cross_residual_attn\n\n        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm\n        norm_class = RMSNorm if use_rmsnorm else norm_class\n        norm_fn = partial(norm_class, dim)\n\n        norm_fn = nn.Identity if use_rezero else norm_fn\n        branch_fn = Rezero if use_rezero else None\n\n        if cross_attend and not only_cross:\n            default_block = ('a', 'c', 'f')\n        elif cross_attend and only_cross:\n            default_block = ('c', 'f')\n        else:\n            default_block = ('a', 'f')\n\n        if macaron:\n            default_block = ('f',) + default_block\n\n        if exists(custom_layers):\n            layer_types = custom_layers\n        elif exists(par_ratio):\n            par_depth = depth * len(default_block)\n            assert 1 < par_ratio <= par_depth, 'par ratio out of range'\n            default_block = tuple(filter(not_equals('f'), default_block))\n            par_attn = par_depth // par_ratio\n            depth_cut = par_depth * 2 // 3  # 2 / 3 attention layer cutoff suggested by PAR paper\n            par_width = (depth_cut + depth_cut // par_attn) // par_attn\n            assert len(default_block) <= par_width, 'default block is too large for par_ratio'\n            par_block = default_block + ('f',) * (par_width - len(default_block))\n            par_head = par_block * par_attn\n            layer_types = par_head + ('f',) * (par_depth - len(par_head))\n        elif exists(sandwich_coef):\n            assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'\n            layer_types = ('a',) * sandwich_coef 
+ default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef\n        else:\n            layer_types = default_block * depth\n\n        self.layer_types = layer_types\n        self.num_attn_layers = len(list(filter(equals('a'), layer_types)))\n\n        for layer_type in self.layer_types:\n            if layer_type == 'a':\n                layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)\n            elif layer_type == 'c':\n                layer = Attention(dim, heads=heads, **attn_kwargs)\n            elif layer_type == 'f':\n                layer = FeedForward(dim, **ff_kwargs)\n                layer = layer if not macaron else Scale(0.5, layer)\n            else:\n                raise Exception(f'invalid layer type {layer_type}')\n\n            if isinstance(layer, Attention) and exists(branch_fn):\n                layer = branch_fn(layer)\n\n            if gate_residual:\n                residual_fn = GRUGating(dim)\n            else:\n                residual_fn = Residual()\n\n            self.layers.append(nn.ModuleList([\n                norm_fn(),\n                layer,\n                residual_fn\n            ]))\n\n    def forward(\n            self,\n            x,\n            context=None,\n            mask=None,\n            context_mask=None,\n            mems=None,\n            return_hiddens=False,\n            **kwargs\n    ):\n        hiddens = []\n        intermediates = []\n        prev_attn = None\n        prev_cross_attn = None\n\n        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers\n\n        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):\n            is_last = ind == (len(self.layers) - 1)\n\n            if layer_type == 'a':\n                hiddens.append(x)\n                layer_mem = mems.pop(0)\n\n            residual = x\n\n            if self.pre_norm:\n                x = norm(x)\n\n            if layer_type == 'a':\n                out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,\n                                   prev_attn=prev_attn, mem=layer_mem)\n            elif layer_type == 'c':\n                out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)\n            elif layer_type == 'f':\n                out = block(x)\n\n            x = residual_fn(out, residual)\n\n            if layer_type in ('a', 'c'):\n                intermediates.append(inter)\n\n            if layer_type == 'a' and self.residual_attn:\n                prev_attn = inter.pre_softmax_attn\n            elif layer_type == 'c' and self.cross_residual_attn:\n                prev_cross_attn = inter.pre_softmax_attn\n\n            if not self.pre_norm and not is_last:\n                x = norm(x)\n\n        if return_hiddens:\n            intermediates = LayerIntermediates(\n                hiddens=hiddens,\n                attn_intermediates=intermediates\n            )\n\n            return x, intermediates\n\n        return x\n\n\nclass Encoder(AttentionLayers):\n    def __init__(self, **kwargs):\n        assert 'causal' not in kwargs, 'cannot set causality on encoder'\n        super().__init__(causal=False, **kwargs)\n\n\n\nclass TransformerWrapper(nn.Module):\n    def __init__(\n            self,\n            *,\n            num_tokens,\n            max_seq_len,\n            attn_layers,\n            emb_dim=None,\n            max_mem_len=0.,\n            emb_dropout=0.,\n            
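# optional learned memory tokens prepended to each sequence (Memory Transformers paper)\n            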
num_memory_tokens=None,\n            tie_embedding=False,\n            use_pos_emb=True\n    ):\n        super().__init__()\n        assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'\n\n        dim = attn_layers.dim\n        emb_dim = default(emb_dim, dim)\n\n        self.max_seq_len = max_seq_len\n        self.max_mem_len = max_mem_len\n        self.num_tokens = num_tokens\n\n        self.token_emb = nn.Embedding(num_tokens, emb_dim)\n        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (\n                    use_pos_emb and not attn_layers.has_pos_emb) else always(0)\n        self.emb_dropout = nn.Dropout(emb_dropout)\n\n        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()\n        self.attn_layers = attn_layers\n        self.norm = nn.LayerNorm(dim)\n\n        self.init_()\n\n        self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()\n\n        # memory tokens (like [cls]) from Memory Transformers paper\n        num_memory_tokens = default(num_memory_tokens, 0)\n        self.num_memory_tokens = num_memory_tokens\n        if num_memory_tokens > 0:\n            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))\n\n            # let funnel encoder know number of memory tokens, if specified\n            if hasattr(attn_layers, 'num_memory_tokens'):\n                attn_layers.num_memory_tokens = num_memory_tokens\n\n    def init_(self):\n        nn.init.normal_(self.token_emb.weight, std=0.02)\n\n    def forward(\n            self,\n            x,\n            return_embeddings=False,\n            mask=None,\n            return_mems=False,\n            return_attn=False,\n            mems=None,\n            embedding_manager=None,\n            **kwargs\n    ):\n        b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens\n\n        embedded_x = self.token_emb(x)\n        \n        if embedding_manager:\n            x = embedding_manager(x, embedded_x)\n        else:\n            x = embedded_x\n\n        x = x + self.pos_emb(x)\n        x = self.emb_dropout(x)\n\n        x = self.project_emb(x)\n\n        if num_mem > 0:\n            mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)\n            x = torch.cat((mem, x), dim=1)\n\n            # auto-handle masking after appending memory tokens\n            if exists(mask):\n                mask = F.pad(mask, (num_mem, 0), value=True)\n\n        x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)\n        x = self.norm(x)\n\n        mem, x = x[:, :num_mem], x[:, num_mem:]\n\n        out = self.to_logits(x) if not return_embeddings else x\n\n        if return_mems:\n            hiddens = intermediates.hiddens\n            new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens\n            new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))\n            return out, new_mems\n\n        if return_attn:\n            attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))\n            return out, attn_maps\n\n        return out\n\n"
  },
  {
    "path": "ldm/util.py",
    "content": "import importlib\n\nimport torch\nimport numpy as np\nfrom collections import abc\nfrom einops import rearrange\nfrom functools import partial\n\nimport multiprocessing as mp\nfrom threading import Thread\nfrom queue import Queue\n\nfrom inspect import isfunction\nfrom PIL import Image, ImageDraw, ImageFont\n\n\ndef log_txt_as_img(wh, xc, size=10):\n    # wh a tuple of (width, height)\n    # xc a list of captions to plot\n    b = len(xc)\n    txts = list()\n    for bi in range(b):\n        txt = Image.new(\"RGB\", wh, color=\"white\")\n        draw = ImageDraw.Draw(txt)\n        font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)\n        nc = int(40 * (wh[0] / 256))\n        lines = \"\\n\".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))\n\n        try:\n            draw.text((0, 0), lines, fill=\"black\", font=font)\n        except UnicodeEncodeError:\n            print(\"Cant encode string for logging. Skipping.\")\n\n        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0\n        txts.append(txt)\n    txts = np.stack(txts)\n    txts = torch.tensor(txts)\n    return txts\n\n\ndef ismap(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return (len(x.shape) == 4) and (x.shape[1] > 3)\n\n\ndef isimage(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)\n\n\ndef exists(x):\n    return x is not None\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef count_params(model, verbose=False):\n    total_params = sum(p.numel() for p in model.parameters())\n    if verbose:\n        print(f\"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.\")\n    return total_params\n\n\ndef instantiate_from_config(config, **kwargs):\n    if not \"target\" in config:\n        if config == '__is_first_stage__':\n            return None\n        elif config == \"__is_unconditional__\":\n            return None\n        raise KeyError(\"Expected key `target` to instantiate.\")\n    return get_obj_from_str(config[\"target\"])(**config.get(\"params\", dict()), **kwargs)\n\n\ndef get_obj_from_str(string, reload=False):\n    module, cls = string.rsplit(\".\", 1)\n    if reload:\n        module_imp = importlib.import_module(module)\n        importlib.reload(module_imp)\n    return getattr(importlib.import_module(module, package=None), cls)\n\n\ndef _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):\n    # create dummy dataset instance\n\n    # run prefetching\n    if idx_to_fn:\n        res = func(data, worker_id=idx)\n    else:\n        res = func(data)\n    Q.put([idx, res])\n    Q.put(\"Done\")\n\n\ndef parallel_data_prefetch(\n        func: callable, data, n_proc, target_data_type=\"ndarray\", cpu_intensive=True, use_worker_id=False\n):\n    # if target_data_type not in [\"ndarray\", \"list\"]:\n    #     raise ValueError(\n    #         \"Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray.\"\n    #     )\n    if isinstance(data, np.ndarray) and target_data_type == \"list\":\n        raise ValueError(\"list expected but function got ndarray.\")\n    
elif isinstance(data, abc.Iterable):\n        if isinstance(data, dict):\n            print(\n                f'WARNING:\"data\" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'\n            )\n            data = list(data.values())\n        if target_data_type == \"ndarray\":\n            data = np.asarray(data)\n        else:\n            data = list(data)\n    else:\n        raise TypeError(\n            f\"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.\"\n        )\n\n    if cpu_intensive:\n        Q = mp.Queue(1000)\n        proc = mp.Process\n    else:\n        Q = Queue(1000)\n        proc = Thread\n    # spawn processes\n    if target_data_type == \"ndarray\":\n        arguments = [\n            [func, Q, part, i, use_worker_id]\n            for i, part in enumerate(np.array_split(data, n_proc))\n        ]\n    else:\n        step = (\n            int(len(data) / n_proc + 1)\n            if len(data) % n_proc != 0\n            else int(len(data) / n_proc)\n        )\n        arguments = [\n            [func, Q, part, i, use_worker_id]\n            for i, part in enumerate(\n                [data[i: i + step] for i in range(0, len(data), step)]\n            )\n        ]\n    processes = []\n    for i in range(n_proc):\n        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])\n        processes += [p]\n\n    # start processes\n    print(f\"Start prefetching...\")\n    import time\n\n    start = time.time()\n    gather_res = [[] for _ in range(n_proc)]\n    try:\n        for p in processes:\n            p.start()\n\n        k = 0\n        while k < n_proc:\n            # get result\n            res = Q.get()\n            if res == \"Done\":\n                k += 1\n            else:\n                gather_res[res[0]] = res[1]\n\n    except Exception as e:\n        print(\"Exception: \", e)\n        for p in processes:\n            p.terminate()\n\n        raise e\n    finally:\n        for p in processes:\n            p.join()\n        print(f\"Prefetching complete. [{time.time() - start} sec.]\")\n\n    if target_data_type == 'ndarray':\n        if not isinstance(gather_res[0], np.ndarray):\n            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)\n\n        # order outputs\n        return np.concatenate(gather_res, axis=0)\n    elif target_data_type == 'list':\n        out = []\n        for r in gather_res:\n            out.extend(r)\n        return out\n    else:\n        return gather_res\n"
  },
  {
    "path": "main.py",
    "content": "import argparse, os, sys, datetime, glob, importlib, csv\nimport numpy as np\nimport time\nimport torch\n\nimport torchvision\nimport pytorch_lightning as pl\n\nfrom packaging import version\nfrom omegaconf import OmegaConf\nfrom torch.utils.data import random_split, DataLoader, Dataset, Subset\nfrom functools import partial\nfrom PIL import Image\n\nfrom pytorch_lightning import seed_everything\nfrom pytorch_lightning.trainer import Trainer\nfrom pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor\nfrom pytorch_lightning.utilities.distributed import rank_zero_only\nfrom pytorch_lightning.utilities import rank_zero_info\n\nfrom ldm.data.base import Txt2ImgIterableBaseDataset\nfrom ldm.util import instantiate_from_config\n\ndef load_model_from_config(config, ckpt, verbose=False):\n    print(f\"Loading model from {ckpt}\")\n    pl_sd = torch.load(ckpt, map_location=\"cpu\")\n    sd = pl_sd[\"state_dict\"]\n    config.model.params.ckpt_path = ckpt\n    model = instantiate_from_config(config.model)\n    m, u = model.load_state_dict(sd, strict=False)\n    if len(m) > 0 and verbose:\n        print(\"missing keys:\")\n        print(m)\n    if len(u) > 0 and verbose:\n        print(\"unexpected keys:\")\n        print(u)\n\n    model.cuda()\n    return model\n\ndef get_parser(**parser_kwargs):\n    def str2bool(v):\n        if isinstance(v, bool):\n            return v\n        if v.lower() in (\"yes\", \"true\", \"t\", \"y\", \"1\"):\n            return True\n        elif v.lower() in (\"no\", \"false\", \"f\", \"n\", \"0\"):\n            return False\n        else:\n            raise argparse.ArgumentTypeError(\"Boolean value expected.\")\n\n    parser = argparse.ArgumentParser(**parser_kwargs)\n    parser.add_argument(\n        \"-n\",\n        \"--name\",\n        type=str,\n        const=True,\n        default=\"\",\n        nargs=\"?\",\n        help=\"postfix for logdir\",\n    )\n    parser.add_argument(\n        \"-r\",\n        \"--resume\",\n        type=str,\n        const=True,\n        default=\"\",\n        nargs=\"?\",\n        help=\"resume from logdir or checkpoint in logdir\",\n    )\n    parser.add_argument(\n        \"-b\",\n        \"--base\",\n        nargs=\"*\",\n        metavar=\"base_config.yaml\",\n        help=\"paths to base configs. Loaded from left-to-right. 
\"\n             \"Parameters can be overwritten or added with command-line options of the form `--key value`.\",\n        default=list(),\n    )\n    parser.add_argument(\n        \"-t\",\n        \"--train\",\n        type=str2bool,\n        const=True,\n        default=False,\n        nargs=\"?\",\n        help=\"train\",\n    )\n    parser.add_argument(\n        \"--no-test\",\n        type=str2bool,\n        const=True,\n        default=False,\n        nargs=\"?\",\n        help=\"disable test\",\n    )\n    parser.add_argument(\n        \"-p\",\n        \"--project\",\n        help=\"name of new or path to existing project\"\n    )\n    parser.add_argument(\n        \"-d\",\n        \"--debug\",\n        type=str2bool,\n        nargs=\"?\",\n        const=True,\n        default=False,\n        help=\"enable post-mortem debugging\",\n    )\n    parser.add_argument(\n        \"-s\",\n        \"--seed\",\n        type=int,\n        default=23,\n        help=\"seed for seed_everything\",\n    )\n    parser.add_argument(\n        \"-f\",\n        \"--postfix\",\n        type=str,\n        default=\"\",\n        help=\"post-postfix for default name\",\n    )\n    parser.add_argument(\n        \"-l\",\n        \"--logdir\",\n        type=str,\n        default=\"logs\",\n        help=\"directory for logging dat shit\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        type=str2bool,\n        nargs=\"?\",\n        const=False,\n        default=False,\n        help=\"scale base-lr by ngpu * batch_size * n_accumulate\",\n    )\n\n    parser.add_argument(\n        \"--datadir_in_name\", \n        type=str2bool, \n        nargs=\"?\", \n        const=True, \n        default=True, \n        help=\"Prepend the final directory in the data_root to the output directory name\")\n\n    parser.add_argument(\"--actual_resume\", \n        type=str,\n        required=True,\n        help=\"Path to model to actually resume from\")\n\n    parser.add_argument(\"--data_root\", \n        type=str, \n        required=True, \n        help=\"Path to directory with training images\")\n    \n    parser.add_argument(\"--reg_data_root\", \n        type=str, \n        required=True, \n        help=\"Path to directory with regularization images\")\n\n    parser.add_argument(\"--embedding_manager_ckpt\", \n        type=str, \n        default=\"\", \n        help=\"Initialize embedding manager from a checkpoint\")\n\n    parser.add_argument(\"--class_word\", \n        type=str, \n        default=\"dog\",\n        help=\"Placeholder token which will be used to denote the concept in future prompts\")\n\n    parser.add_argument(\"--init_word\", \n        type=str, \n        help=\"Word to use as source for initial token embedding\")\n\n    return parser\n\n\ndef nondefault_trainer_args(opt):\n    parser = argparse.ArgumentParser()\n    parser = Trainer.add_argparse_args(parser)\n    args = parser.parse_args([])\n    return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))\n\n\nclass WrappedDataset(Dataset):\n    \"\"\"Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset\"\"\"\n\n    def __init__(self, dataset):\n        self.data = dataset\n\n    def __len__(self):\n        return len(self.data)\n\n    def __getitem__(self, idx):\n        return self.data[idx]\n\n\ndef worker_init_fn(_):\n    worker_info = torch.utils.data.get_worker_info()\n\n    dataset = worker_info.dataset\n    worker_id = worker_info.id\n\n    if isinstance(dataset, 
Txt2ImgIterableBaseDataset):\n        split_size = dataset.num_records // worker_info.num_workers\n        # reset num_records to the true number to retain reliable length information\n        dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]\n        current_id = np.random.choice(len(np.random.get_state()[1]), 1)\n        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)\n    else:\n        return np.random.seed(np.random.get_state()[1][0] + worker_id)\n\nclass ConcatDataset(Dataset):\n    def __init__(self, *datasets):\n        self.datasets = datasets\n\n    def __getitem__(self, idx):\n        return tuple(d[idx] for d in self.datasets)\n\n    def __len__(self):\n        return min(len(d) for d in self.datasets)\n    \nclass DataModuleFromConfig(pl.LightningDataModule):\n    def __init__(self, batch_size, train=None, reg = None, validation=None, test=None, predict=None,\n                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,\n                 shuffle_val_dataloader=False):\n        super().__init__()\n        self.batch_size = batch_size\n        self.dataset_configs = dict()\n        self.num_workers = num_workers if num_workers is not None else batch_size * 2\n        self.use_worker_init_fn = use_worker_init_fn\n        if train is not None:\n            self.dataset_configs[\"train\"] = train\n        if reg is not None:\n            self.dataset_configs[\"reg\"] = reg\n        \n        self.train_dataloader = self._train_dataloader\n        \n        if validation is not None:\n            self.dataset_configs[\"validation\"] = validation\n            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)\n        if test is not None:\n            self.dataset_configs[\"test\"] = test\n            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)\n        if predict is not None:\n            self.dataset_configs[\"predict\"] = predict\n            self.predict_dataloader = self._predict_dataloader\n        self.wrap = wrap\n\n    def prepare_data(self):\n        for data_cfg in self.dataset_configs.values():\n            instantiate_from_config(data_cfg)\n\n    def setup(self, stage=None):\n        self.datasets = dict(\n            (k, instantiate_from_config(self.dataset_configs[k]))\n            for k in self.dataset_configs)\n        if self.wrap:\n            for k in self.datasets:\n                self.datasets[k] = WrappedDataset(self.datasets[k])\n\n    def _train_dataloader(self):\n        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)\n        if is_iterable_dataset or self.use_worker_init_fn:\n            init_fn = worker_init_fn\n        else:\n            init_fn = None\n        train_set = self.datasets[\"train\"]\n        reg_set = self.datasets[\"reg\"]\n        concat_dataset = ConcatDataset(train_set, reg_set)\n        return DataLoader(concat_dataset, batch_size=self.batch_size,\n                          num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,\n                          worker_init_fn=init_fn)\n\n    def _val_dataloader(self, shuffle=False):\n        if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:\n            init_fn = worker_init_fn\n        else:\n            init_fn = None\n        return DataLoader(self.datasets[\"validation\"],\n                          
batch_size=self.batch_size,\n                          num_workers=self.num_workers,\n                          worker_init_fn=init_fn,\n                          shuffle=shuffle)\n\n    def _test_dataloader(self, shuffle=False):\n        is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)\n        if is_iterable_dataset or self.use_worker_init_fn:\n            init_fn = worker_init_fn\n        else:\n            init_fn = None\n\n        # do not shuffle dataloader for iterable dataset\n        shuffle = shuffle and (not is_iterable_dataset)\n\n        return DataLoader(self.datasets[\"test\"], batch_size=self.batch_size,\n                          num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)\n\n    def _predict_dataloader(self, shuffle=False):\n        if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:\n            init_fn = worker_init_fn\n        else:\n            init_fn = None\n        return DataLoader(self.datasets[\"predict\"], batch_size=self.batch_size,\n                          num_workers=self.num_workers, worker_init_fn=init_fn)\n\n\nclass SetupCallback(Callback):\n    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):\n        super().__init__()\n        self.resume = resume\n        self.now = now\n        self.logdir = logdir\n        self.ckptdir = ckptdir\n        self.cfgdir = cfgdir\n        self.config = config\n        self.lightning_config = lightning_config\n\n    def on_keyboard_interrupt(self, trainer, pl_module):\n        if trainer.global_rank == 0:\n            print(\"Summoning checkpoint.\")\n            ckpt_path = os.path.join(self.ckptdir, \"last.ckpt\")\n            trainer.save_checkpoint(ckpt_path)\n\n    def on_pretrain_routine_start(self, trainer, pl_module):\n        if trainer.global_rank == 0:\n            # Create logdirs and save configs\n            os.makedirs(self.logdir, exist_ok=True)\n            os.makedirs(self.ckptdir, exist_ok=True)\n            os.makedirs(self.cfgdir, exist_ok=True)\n\n            if \"callbacks\" in self.lightning_config:\n                if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:\n                    os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)\n            print(\"Project config\")\n            print(OmegaConf.to_yaml(self.config))\n            OmegaConf.save(self.config,\n                           os.path.join(self.cfgdir, \"{}-project.yaml\".format(self.now)))\n\n            print(\"Lightning config\")\n            print(OmegaConf.to_yaml(self.lightning_config))\n            OmegaConf.save(OmegaConf.create({\"lightning\": self.lightning_config}),\n                           os.path.join(self.cfgdir, \"{}-lightning.yaml\".format(self.now)))\n\n        else:\n            # ModelCheckpoint callback created log directory --- remove it\n            if not self.resume and os.path.exists(self.logdir):\n                dst, name = os.path.split(self.logdir)\n                dst = os.path.join(dst, \"child_runs\", name)\n                os.makedirs(os.path.split(dst)[0], exist_ok=True)\n                try:\n                    os.rename(self.logdir, dst)\n                except FileNotFoundError:\n                    pass\n\n\nclass ImageLogger(Callback):\n    def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,\n                 rescale=True, disabled=False, 
log_on_batch_idx=False, log_first_step=False,\n                 log_images_kwargs=None):\n        super().__init__()\n        self.rescale = rescale\n        self.batch_freq = batch_frequency\n        self.max_images = max_images\n        self.logger_log_images = {\n            pl.loggers.TestTubeLogger: self._testtube,\n        }\n        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]\n        if not increase_log_steps:\n            self.log_steps = [self.batch_freq]\n        self.clamp = clamp\n        self.disabled = disabled\n        self.log_on_batch_idx = log_on_batch_idx\n        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}\n        self.log_first_step = log_first_step\n\n    @rank_zero_only\n    def _testtube(self, pl_module, images, batch_idx, split):\n        for k in images:\n            grid = torchvision.utils.make_grid(images[k])\n            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w\n\n            tag = f\"{split}/{k}\"\n            pl_module.logger.experiment.add_image(\n                tag, grid,\n                global_step=pl_module.global_step)\n\n    @rank_zero_only\n    def log_local(self, save_dir, split, images,\n                  global_step, current_epoch, batch_idx):\n        root = os.path.join(save_dir, \"images\", split)\n        for k in images:\n            grid = torchvision.utils.make_grid(images[k], nrow=4)\n            if self.rescale:\n                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w\n            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)\n            grid = grid.numpy()\n            grid = (grid * 255).astype(np.uint8)\n            filename = \"{}_gs-{:06}_e-{:06}_b-{:06}.jpg\".format(\n                k,\n                global_step,\n                current_epoch,\n                batch_idx)\n            path = os.path.join(root, filename)\n            os.makedirs(os.path.split(path)[0], exist_ok=True)\n            Image.fromarray(grid).save(path)\n\n    def log_img(self, pl_module, batch, batch_idx, split=\"train\"):\n        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step\n        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0\n                hasattr(pl_module, \"log_images\") and\n                callable(pl_module.log_images) and\n                self.max_images > 0):\n            logger = type(pl_module.logger)\n\n            is_train = pl_module.training\n            if is_train:\n                pl_module.eval()\n\n            with torch.no_grad():\n                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)\n\n            for k in images:\n                N = min(images[k].shape[0], self.max_images)\n                images[k] = images[k][:N]\n                if isinstance(images[k], torch.Tensor):\n                    images[k] = images[k].detach().cpu()\n                    if self.clamp:\n                        images[k] = torch.clamp(images[k], -1., 1.)\n\n            self.log_local(pl_module.logger.save_dir, split, images,\n                           pl_module.global_step, pl_module.current_epoch, batch_idx)\n\n            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)\n            logger_log_images(pl_module, images, pl_module.global_step, split)\n\n            if is_train:\n                pl_module.train()\n\n    def check_frequency(self, check_idx):\n        if ((check_idx % self.batch_freq) == 0 or (check_idx in 
self.log_steps)) and (\n                check_idx > 0 or self.log_first_step):\n            try:\n                self.log_steps.pop(0)\n            except IndexError as e:\n                print(e)\n                pass\n            return True\n        return False\n\n    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):\n        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):\n            self.log_img(pl_module, batch, batch_idx, split=\"train\")\n\n    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):\n        if not self.disabled and pl_module.global_step > 0:\n            self.log_img(pl_module, batch, batch_idx, split=\"val\")\n        if hasattr(pl_module, 'calibrate_grad_norm'):\n            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:\n                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)\n\n\nclass CUDACallback(Callback):\n    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py\n    def on_train_epoch_start(self, trainer, pl_module):\n        # Reset the memory use counter\n        torch.cuda.reset_peak_memory_stats(trainer.root_gpu)\n        torch.cuda.synchronize(trainer.root_gpu)\n        self.start_time = time.time()\n\n    def on_train_epoch_end(self, trainer, pl_module):\n        torch.cuda.synchronize(trainer.root_gpu)\n        max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20\n        epoch_time = time.time() - self.start_time\n\n        try:\n            max_memory = trainer.training_type_plugin.reduce(max_memory)\n            epoch_time = trainer.training_type_plugin.reduce(epoch_time)\n\n            rank_zero_info(f\"Average Epoch time: {epoch_time:.2f} seconds\")\n            rank_zero_info(f\"Average Peak memory {max_memory:.2f}MiB\")\n        except AttributeError:\n            pass\n\nclass ModeSwapCallback(Callback):\n\n    def __init__(self, swap_step=2000):\n        super().__init__()\n        self.is_frozen = False\n        self.swap_step = swap_step\n\n    def on_train_epoch_start(self, trainer, pl_module):\n        if trainer.global_step < self.swap_step and not self.is_frozen:\n            self.is_frozen = True\n            trainer.optimizers = [pl_module.configure_opt_embedding()]\n\n        if trainer.global_step > self.swap_step and self.is_frozen:\n            self.is_frozen = False\n            trainer.optimizers = [pl_module.configure_opt_model()]\n\nif __name__ == \"__main__\":\n\n    now = datetime.datetime.now().strftime(\"%Y-%m-%dT%H-%M-%S\")\n\n    # add cwd for convenience and to make classes in this file available when\n    # running as `python main.py`\n    # (in particular `main.DataModuleFromConfig`)\n    sys.path.append(os.getcwd())\n\n    parser = get_parser()\n    parser = Trainer.add_argparse_args(parser)\n\n    opt, unknown = parser.parse_known_args()\n    if opt.name and opt.resume:\n        raise ValueError(\n            \"-n/--name and -r/--resume cannot be specified both.\"\n            \"If you want to resume training in a new log folder, \"\n            \"use -n/--name in combination with --resume_from_checkpoint\"\n        )\n    if opt.resume:\n        if not os.path.exists(opt.resume):\n            raise ValueError(\"Cannot find {}\".format(opt.resume))\n        if os.path.isfile(opt.resume):\n            paths = opt.resume.split(\"/\")\n            # idx = len(paths)-paths[::-1].index(\"logs\")+1\n            # 
logdir = \"/\".join(paths[:idx])\n            logdir = \"/\".join(paths[:-2])\n            ckpt = opt.resume\n        else:\n            assert os.path.isdir(opt.resume), opt.resume\n            logdir = opt.resume.rstrip(\"/\")\n            ckpt = os.path.join(logdir, \"checkpoints\", \"last.ckpt\")\n\n        opt.resume_from_checkpoint = ckpt\n        base_configs = sorted(glob.glob(os.path.join(logdir, \"configs/*.yaml\")))\n        opt.base = base_configs + opt.base\n        _tmp = logdir.split(\"/\")\n        nowname = _tmp[-1]\n    else:\n        if opt.name:\n            name = \"_\" + opt.name\n        elif opt.base:\n            cfg_fname = os.path.split(opt.base[0])[-1]\n            cfg_name = os.path.splitext(cfg_fname)[0]\n            name = \"_\" + cfg_name\n        else:\n            name = \"\"\n\n        if opt.datadir_in_name:\n            now = os.path.basename(os.path.normpath(opt.data_root)) + now\n            \n        nowname = now + name + opt.postfix\n        logdir = os.path.join(opt.logdir, nowname)\n\n    ckptdir = os.path.join(logdir, \"checkpoints\")\n    cfgdir = os.path.join(logdir, \"configs\")\n    seed_everything(opt.seed)\n\n    try:\n        # init and save configs\n        configs = [OmegaConf.load(cfg) for cfg in opt.base]\n        cli = OmegaConf.from_dotlist(unknown)\n        config = OmegaConf.merge(*configs, cli)\n        lightning_config = config.pop(\"lightning\", OmegaConf.create())\n        # merge trainer cli with config\n        trainer_config = lightning_config.get(\"trainer\", OmegaConf.create())\n        # default to ddp\n        trainer_config[\"accelerator\"] = \"ddp\"\n        for k in nondefault_trainer_args(opt):\n            trainer_config[k] = getattr(opt, k)\n        if not \"gpus\" in trainer_config:\n            del trainer_config[\"accelerator\"]\n            cpu = True\n        else:\n            gpuinfo = trainer_config[\"gpus\"]\n            print(f\"Running on GPUs {gpuinfo}\")\n            cpu = False\n        trainer_opt = argparse.Namespace(**trainer_config)\n        lightning_config.trainer = trainer_config\n\n        # model\n\n        # config.model.params.personalization_config.params.init_word = opt.init_word\n        # config.model.params.personalization_config.params.embedding_manager_ckpt = opt.embedding_manager_ckpt\n        # config.model.params.personalization_config.params.placeholder_tokens = opt.placeholder_tokens\n\n        # if opt.init_word:\n        #     config.model.params.personalization_config.params.initializer_words[0] = opt.init_word\n            \n        config.data.params.train.params.placeholder_token = opt.class_word\n        config.data.params.reg.params.placeholder_token = opt.class_word\n        config.data.params.validation.params.placeholder_token = opt.class_word\n\n        if opt.actual_resume:\n            model = load_model_from_config(config, opt.actual_resume)\n        else:\n            model = instantiate_from_config(config.model)\n\n        # trainer and callbacks\n        trainer_kwargs = dict()\n\n        # default logger configs\n        default_logger_cfgs = {\n            \"wandb\": {\n                \"target\": \"pytorch_lightning.loggers.WandbLogger\",\n                \"params\": {\n                    \"name\": nowname,\n                    \"save_dir\": logdir,\n                    \"offline\": opt.debug,\n                    \"id\": nowname,\n                }\n            },\n            \"testtube\": {\n                \"target\": 
\"pytorch_lightning.loggers.TestTubeLogger\",\n                \"params\": {\n                    \"name\": \"testtube\",\n                    \"save_dir\": logdir,\n                }\n            },\n        }\n        default_logger_cfg = default_logger_cfgs[\"testtube\"]\n        if \"logger\" in lightning_config:\n            logger_cfg = lightning_config.logger\n        else:\n            logger_cfg = OmegaConf.create()\n        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)\n        trainer_kwargs[\"logger\"] = instantiate_from_config(logger_cfg)\n\n        # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to\n        # specify which metric is used to determine best models\n        default_modelckpt_cfg = {\n            \"target\": \"pytorch_lightning.callbacks.ModelCheckpoint\",\n            \"params\": {\n                \"dirpath\": ckptdir,\n                \"filename\": \"{epoch:06}\",\n                \"verbose\": True,\n                \"save_last\": True,\n            }\n        }\n        if hasattr(model, \"monitor\"):\n            print(f\"Monitoring {model.monitor} as checkpoint metric.\")\n            default_modelckpt_cfg[\"params\"][\"monitor\"] = model.monitor\n            default_modelckpt_cfg[\"params\"][\"save_top_k\"] = 1\n\n        if \"modelcheckpoint\" in lightning_config:\n            modelckpt_cfg = lightning_config.modelcheckpoint\n        else:\n            modelckpt_cfg =  OmegaConf.create()\n        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)\n        print(f\"Merged modelckpt-cfg: \\n{modelckpt_cfg}\")\n        if version.parse(pl.__version__) < version.parse('1.4.0'):\n            trainer_kwargs[\"checkpoint_callback\"] = instantiate_from_config(modelckpt_cfg)\n\n        # add callback which sets up log directory\n        default_callbacks_cfg = {\n            \"setup_callback\": {\n                \"target\": \"main.SetupCallback\",\n                \"params\": {\n                    \"resume\": opt.resume,\n                    \"now\": now,\n                    \"logdir\": logdir,\n                    \"ckptdir\": ckptdir,\n                    \"cfgdir\": cfgdir,\n                    \"config\": config,\n                    \"lightning_config\": lightning_config,\n                }\n            },\n            \"image_logger\": {\n                \"target\": \"main.ImageLogger\",\n                \"params\": {\n                    \"batch_frequency\": 750,\n                    \"max_images\": 4,\n                    \"clamp\": True\n                }\n            },\n            \"learning_rate_logger\": {\n                \"target\": \"main.LearningRateMonitor\",\n                \"params\": {\n                    \"logging_interval\": \"step\",\n                    # \"log_momentum\": True\n                }\n            },\n            \"cuda_callback\": {\n                \"target\": \"main.CUDACallback\"\n            },\n        }\n        if version.parse(pl.__version__) >= version.parse('1.4.0'):\n            default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})\n\n        if \"callbacks\" in lightning_config:\n            callbacks_cfg = lightning_config.callbacks\n        else:\n            callbacks_cfg = OmegaConf.create()\n\n        if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:\n            print(\n                'Caution: Saving checkpoints every n train steps without deleting. 
This might require some free space.')\n            default_metrics_over_trainsteps_ckpt_dict = {\n                'metrics_over_trainsteps_checkpoint':\n                    {\"target\": 'pytorch_lightning.callbacks.ModelCheckpoint',\n                     'params': {\n                         \"dirpath\": os.path.join(ckptdir, 'trainstep_checkpoints'),\n                         \"filename\": \"{epoch:06}-{step:09}\",\n                         \"verbose\": True,\n                         'save_top_k': -1,\n                         'every_n_train_steps': 10000,\n                         'save_weights_only': True\n                     }\n                     }\n            }\n            default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)\n\n        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)\n        if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):\n            callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint\n        elif 'ignore_keys_callback' in callbacks_cfg:\n            del callbacks_cfg['ignore_keys_callback']\n\n        trainer_kwargs[\"callbacks\"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]\n        trainer_kwargs[\"max_steps\"] = trainer_opt.max_steps\n\n        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)\n        trainer.logdir = logdir  ###\n\n        # data\n        config.data.params.train.params.data_root = opt.data_root\n        config.data.params.reg.params.data_root = opt.reg_data_root\n        config.data.params.validation.params.data_root = opt.data_root\n        data = instantiate_from_config(config.data)\n        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html\n        # calling these ourselves should not be necessary but it is.\n        # lightning still takes care of proper multiprocessing though\n        data.prepare_data()\n        data.setup()\n        print(\"#### Data #####\")\n        for k in data.datasets:\n            print(f\"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}\")\n\n        # configure learning rate\n        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate\n        if not cpu:\n            ngpu = len(lightning_config.trainer.gpus.strip(\",\").split(','))\n        else:\n            ngpu = 1\n        if 'accumulate_grad_batches' in lightning_config.trainer:\n            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches\n        else:\n            accumulate_grad_batches = 1\n        print(f\"accumulate_grad_batches = {accumulate_grad_batches}\")\n        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches\n        if opt.scale_lr:\n            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr\n            print(\n                \"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)\".format(\n                    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))\n        else:\n            model.learning_rate = base_lr\n            print(\"++++ NOT USING LR SCALING ++++\")\n            print(f\"Setting learning rate to {model.learning_rate:.2e}\")\n\n\n        # allow checkpointing via USR1\n        def melk(*args, **kwargs):\n            # run all checkpoint hooks\n            if 
trainer.global_rank == 0:\n                print(\"Summoning checkpoint.\")\n                ckpt_path = os.path.join(ckptdir, \"last.ckpt\")\n                trainer.save_checkpoint(ckpt_path)\n\n\n        def divein(*args, **kwargs):\n            if trainer.global_rank == 0:\n                import pudb;\n                pudb.set_trace()\n\n\n        import signal\n\n        signal.signal(signal.SIGUSR1, melk)\n        signal.signal(signal.SIGUSR2, divein)\n\n        # run\n        if opt.train:\n            try:\n                trainer.fit(model, data)\n            except Exception:\n                melk()\n                raise\n        if not opt.no_test and not trainer.interrupted:\n            trainer.test(model, data)\n    except Exception:\n        if opt.debug and trainer.global_rank == 0:\n            try:\n                import ipdb as debugger\n            except ImportError:\n                import pdb as debugger\n            debugger.post_mortem()\n        raise\n    finally:\n        # move newly created debug project to debug_runs\n        if opt.debug and not opt.resume and trainer.global_rank == 0:\n            dst, name = os.path.split(logdir)\n            dst = os.path.join(dst, \"debug_runs\", name)\n            os.makedirs(os.path.split(dst)[0], exist_ok=True)\n            os.rename(logdir, dst)\n        if trainer.global_rank == 0:\n            print(trainer.profiler.summary())\n"
  },
  {
    "path": "scripts/download_first_stages.sh",
    "content": "#!/bin/bash\nwget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip\nwget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip\nwget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip\nwget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip\nwget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip\nwget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip\nwget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip\nwget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip\nwget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip\n\n\n\ncd models/first_stage_models/kl-f4\nunzip -o model.zip\n\ncd ../kl-f8\nunzip -o model.zip\n\ncd ../kl-f16\nunzip -o model.zip\n\ncd ../kl-f32\nunzip -o model.zip\n\ncd ../vq-f4\nunzip -o model.zip\n\ncd ../vq-f4-noattn\nunzip -o model.zip\n\ncd ../vq-f8\nunzip -o model.zip\n\ncd ../vq-f8-n256\nunzip -o model.zip\n\ncd ../vq-f16\nunzip -o model.zip\n\ncd ../.."
  },
  {
    "path": "scripts/download_models.sh",
    "content": "#!/bin/bash\nwget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip\nwget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip\nwget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip\nwget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip\nwget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip\nwget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip\nwget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip\nwget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip\nwget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip\nwget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip\nwget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip\n\n\n\ncd models/ldm/celeba256\nunzip -o celeba-256.zip\n\ncd ../ffhq256\nunzip -o ffhq-256.zip\n\ncd ../lsun_churches256\nunzip -o lsun_churches-256.zip\n\ncd ../lsun_beds256\nunzip -o lsun_beds-256.zip\n\ncd ../text2img256\nunzip -o model.zip\n\ncd ../cin256\nunzip -o model.zip\n\ncd ../semantic_synthesis512\nunzip -o model.zip\n\ncd ../semantic_synthesis256\nunzip -o model.zip\n\ncd ../bsr_sr\nunzip -o model.zip\n\ncd ../layout2img-openimages256\nunzip -o model.zip\n\ncd ../inpainting_big\nunzip -o model.zip\n\ncd ../..\n"
  },
  {
    "path": "scripts/sample_diffusion.py",
    "content": "import argparse, os, sys, glob, datetime, yaml\nimport torch\nimport time\nimport numpy as np\nfrom tqdm import trange\n\nfrom omegaconf import OmegaConf\nfrom PIL import Image\n\nfrom ldm.models.diffusion.ddim import DDIMSampler\nfrom ldm.util import instantiate_from_config\n\nrescale = lambda x: (x + 1.) / 2.\n\ndef custom_to_pil(x):\n    x = x.detach().cpu()\n    x = torch.clamp(x, -1., 1.)\n    x = (x + 1.) / 2.\n    x = x.permute(1, 2, 0).numpy()\n    x = (255 * x).astype(np.uint8)\n    x = Image.fromarray(x)\n    if not x.mode == \"RGB\":\n        x = x.convert(\"RGB\")\n    return x\n\n\ndef custom_to_np(x):\n    # saves the batch in adm style as in https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py\n    sample = x.detach().cpu()\n    sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)\n    sample = sample.permute(0, 2, 3, 1)\n    sample = sample.contiguous()\n    return sample\n\n\ndef logs2pil(logs, keys=[\"sample\"]):\n    imgs = dict()\n    for k in logs:\n        try:\n            if len(logs[k].shape) == 4:\n                img = custom_to_pil(logs[k][0, ...])\n            elif len(logs[k].shape) == 3:\n                img = custom_to_pil(logs[k])\n            else:\n                print(f\"Unknown format for key {k}. \")\n                img = None\n        except:\n            img = None\n        imgs[k] = img\n    return imgs\n\n\n@torch.no_grad()\ndef convsample(model, shape, return_intermediates=True,\n               verbose=True,\n               make_prog_row=False):\n\n\n    if not make_prog_row:\n        return model.p_sample_loop(None, shape,\n                                   return_intermediates=return_intermediates, verbose=verbose)\n    else:\n        return model.progressive_denoising(\n            None, shape, verbose=True\n        )\n\n\n@torch.no_grad()\ndef convsample_ddim(model, steps, shape, eta=1.0\n                    ):\n    ddim = DDIMSampler(model)\n    bs = shape[0]\n    shape = shape[1:]\n    samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,)\n    return samples, intermediates\n\n\n@torch.no_grad()\ndef make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):\n\n\n    log = dict()\n\n    shape = [batch_size,\n             model.model.diffusion_model.in_channels,\n             model.model.diffusion_model.image_size,\n             model.model.diffusion_model.image_size]\n\n    with model.ema_scope(\"Plotting\"):\n        t0 = time.time()\n        if vanilla:\n            sample, progrow = convsample(model, shape,\n                                         make_prog_row=True)\n        else:\n            sample, intermediates = convsample_ddim(model,  steps=custom_steps, shape=shape,\n                                                    eta=eta)\n\n        t1 = time.time()\n\n    x_sample = model.decode_first_stage(sample)\n\n    log[\"sample\"] = x_sample\n    log[\"time\"] = t1 - t0\n    log['throughput'] = sample.shape[0] / (t1 - t0)\n    print(f'Throughput for this batch: {log[\"throughput\"]}')\n    return log\n\ndef run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):\n    if vanilla:\n        print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')\n    else:\n        print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')\n\n\n    tstart = time.time()\n    n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1\n    
# path = logdir\n    if model.cond_stage_model is None:\n        all_images = []\n\n        print(f\"Running unconditional sampling for {n_samples} samples\")\n        for _ in trange(n_samples // batch_size, desc=\"Sampling Batches (unconditional)\"):\n            logs = make_convolutional_sample(model, batch_size=batch_size,\n                                             vanilla=vanilla, custom_steps=custom_steps,\n                                             eta=eta)\n            n_saved = save_logs(logs, logdir, n_saved=n_saved, key=\"sample\")\n            all_images.extend([custom_to_np(logs[\"sample\"])])\n            if n_saved >= n_samples:\n                print(f'Finished after generating {n_saved} samples')\n                break\n        all_img = np.concatenate(all_images, axis=0)\n        all_img = all_img[:n_samples]\n        shape_str = \"x\".join([str(x) for x in all_img.shape])\n        nppath = os.path.join(nplog, f\"{shape_str}-samples.npz\")\n        np.savez(nppath, all_img)\n\n    else:\n        raise NotImplementedError('Currently only sampling for unconditional models supported.')\n\n    print(f\"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.\")\n\n\ndef save_logs(logs, path, n_saved=0, key=\"sample\", np_path=None):\n    for k in logs:\n        if k == key:\n            batch = logs[key]\n            if np_path is None:\n                for x in batch:\n                    img = custom_to_pil(x)\n                    imgpath = os.path.join(path, f\"{key}_{n_saved:06}.png\")\n                    img.save(imgpath)\n                    n_saved += 1\n            else:\n                npbatch = custom_to_np(batch)\n                shape_str = \"x\".join([str(x) for x in npbatch.shape])\n                nppath = os.path.join(np_path, f\"{n_saved}-{shape_str}-samples.npz\")\n                np.savez(nppath, npbatch)\n                n_saved += npbatch.shape[0]\n    return n_saved\n\n\ndef get_parser():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"-r\",\n        \"--resume\",\n        type=str,\n        nargs=\"?\",\n        help=\"load from logdir or checkpoint in logdir\",\n    )\n    parser.add_argument(\n        \"-n\",\n        \"--n_samples\",\n        type=int,\n        nargs=\"?\",\n        help=\"number of samples to draw\",\n        default=50000\n    )\n    parser.add_argument(\n        \"-e\",\n        \"--eta\",\n        type=float,\n        nargs=\"?\",\n        help=\"eta for ddim sampling (0.0 yields deterministic sampling)\",\n        default=1.0\n    )\n    parser.add_argument(\n        \"-v\",\n        \"--vanilla_sample\",\n        default=False,\n        action='store_true',\n        help=\"vanilla sampling (default option is DDIM sampling)?\",\n    )\n    parser.add_argument(\n        \"-l\",\n        \"--logdir\",\n        type=str,\n        nargs=\"?\",\n        help=\"extra logdir\",\n        default=\"none\"\n    )\n    parser.add_argument(\n        \"-c\",\n        \"--custom_steps\",\n        type=int,\n        nargs=\"?\",\n        help=\"number of steps for ddim and fastdpm sampling\",\n        default=50\n    )\n    parser.add_argument(\n        \"--batch_size\",\n        type=int,\n        nargs=\"?\",\n        help=\"batch size\",\n        default=10\n    )\n    return parser\n\n\ndef load_model_from_config(config, sd):\n    model = instantiate_from_config(config)\n    model.load_state_dict(sd, strict=False)\n    model.cuda()\n    model.eval()\n    return model\n\n\ndef 
load_model(config, ckpt, gpu, eval_mode):\n    if ckpt:\n        print(f\"Loading model from {ckpt}\")\n        pl_sd = torch.load(ckpt, map_location=\"cpu\")\n        global_step = pl_sd[\"global_step\"]\n    else:\n        pl_sd = {\"state_dict\": None}\n        global_step = None\n    model = load_model_from_config(config.model,\n                                   pl_sd[\"state_dict\"])\n\n    return model, global_step\n\n\nif __name__ == \"__main__\":\n    now = datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n    sys.path.append(os.getcwd())\n    command = \" \".join(sys.argv)\n\n    parser = get_parser()\n    opt, unknown = parser.parse_known_args()\n    ckpt = None\n\n    if not os.path.exists(opt.resume):\n        raise ValueError(\"Cannot find {}\".format(opt.resume))\n    if os.path.isfile(opt.resume):\n        # paths = opt.resume.split(\"/\")\n        try:\n            logdir = '/'.join(opt.resume.split('/')[:-1])\n            # idx = len(paths)-paths[::-1].index(\"logs\")+1\n            print(f'Logdir is {logdir}')\n        except ValueError:\n            paths = opt.resume.split(\"/\")\n            idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt\n            logdir = \"/\".join(paths[:idx])\n        ckpt = opt.resume\n    else:\n        assert os.path.isdir(opt.resume), f\"{opt.resume} is not a directory\"\n        logdir = opt.resume.rstrip(\"/\")\n        ckpt = os.path.join(logdir, \"model.ckpt\")\n\n    base_configs = sorted(glob.glob(os.path.join(logdir, \"config.yaml\")))\n    opt.base = base_configs\n\n    configs = [OmegaConf.load(cfg) for cfg in opt.base]\n    cli = OmegaConf.from_dotlist(unknown)\n    config = OmegaConf.merge(*configs, cli)\n\n    gpu = True\n    eval_mode = True\n\n    if opt.logdir != \"none\":\n        locallog = logdir.split(os.sep)[-1]\n        if locallog == \"\": locallog = logdir.split(os.sep)[-2]\n        print(f\"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'\")\n        logdir = os.path.join(opt.logdir, locallog)\n\n    print(config)\n\n    model, global_step = load_model(config, ckpt, gpu, eval_mode)\n    print(f\"global step: {global_step}\")\n    print(75 * \"=\")\n    print(\"logging to:\")\n    logdir = os.path.join(logdir, \"samples\", f\"{global_step:08}\", now)\n    imglogdir = os.path.join(logdir, \"img\")\n    numpylogdir = os.path.join(logdir, \"numpy\")\n\n    os.makedirs(imglogdir)\n    os.makedirs(numpylogdir)\n    print(logdir)\n    print(75 * \"=\")\n\n    # write config out\n    sampling_file = os.path.join(logdir, \"sampling_config.yaml\")\n    sampling_conf = vars(opt)\n\n    with open(sampling_file, 'w') as f:\n        yaml.dump(sampling_conf, f, default_flow_style=False)\n    print(sampling_conf)\n\n\n    run(model, imglogdir, eta=opt.eta,\n        vanilla=opt.vanilla_sample,  n_samples=opt.n_samples, custom_steps=opt.custom_steps,\n        batch_size=opt.batch_size, nplog=numpylogdir)\n\n    print(\"done.\")\n"
  },
  {
    "path": "scripts/score.py",
    "content": "import clip\nimport torch\nimport os\nfrom PIL import Image\nimport lpips\nfrom torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize\nfrom einops import repeat\nimport numpy as np\nimport pickle\ntry:\n    from torchvision.transforms import InterpolationMode\n    BICUBIC = InterpolationMode.BICUBIC\nexcept ImportError:\n    BICUBIC = Image.BICUBIC\n    \nimport argparse\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--img_dir', type=str, required=True)\nparser.add_argument('--mode', type=str, default='K', choices=['K', 'beta'])\nparser.add_argument('--prompt', type=str, required=True)\nparser.add_argument('--orig_img', type=str, required=True)\nopt = parser.parse_args()\n\n\ndef _transform():\n    return Compose([\n        Resize((512, 512), interpolation=BICUBIC),\n        _convert_image_to_rgb,\n        ToTensor(),\n        # Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),\n        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n    ])\n    \ndef _convert_image_to_rgb(image):\n    return image.convert(\"RGB\")\n\nORIG_IMAGE_PATH = opt.orig_img\nTGT_TEXT = opt.prompt\nIMAGE_DIR = opt.img_dir\n\n# import ipdb\n# ipdb.set_trace()\nif opt.mode == 'K':\n    Xs = [0, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]\nelse:\n    Xs = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]\nNUM_IMGS_PER_EXP = 20\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\nimg_paths = [os.path.join(IMAGE_DIR, f) for f in sorted(os.listdir(IMAGE_DIR)) if os.path.isfile(os.path.join(IMAGE_DIR, f))]\n\nloss_fn = lpips.LPIPS(net='alex').to(device)\norig_img = Image.open(ORIG_IMAGE_PATH)\ntransform = _transform()\norig_img = transform(orig_img).unsqueeze(0).to(device)\n# orig_img = repeat(orig_img, '1 c h w -> n c h w', n=len(img_paths))\nimgs = [transform(Image.open(img_path)).unsqueeze(0).to(device) for img_path in img_paths]\nlpips_scores = [loss_fn(img, orig_img).item() for img in imgs]\nprint(lpips_scores)\n\nmodel, preprocess = clip.load(\"ViT-B/32\", device=device)\n\n\nimgs = [preprocess(Image.open(img_path)).unsqueeze(0).to(device) for img_path in img_paths]\ntext_tokens = clip.tokenize(TGT_TEXT).to(device)\nwith torch.no_grad():\n    text_features = model.encode_text(text_tokens)\n    img_features = model.encode_image(torch.cat(imgs, dim=0))\n    img_features /= img_features.norm(dim=-1, keepdim=True)\n    text_features /= text_features.norm(dim=-1, keepdim=True)\n    similarity = (100.0 * img_features @ text_features.T).squeeze()\n\nprint(similarity)\nl_score_exps = list()\nc_score_exps = list()\nx_axis = list()\nfor i in range(len(img_paths)):\n    if i % NUM_IMGS_PER_EXP == 0:\n        l_score_exps.append(list())\n        c_score_exps.append(list())\n        x_axis.append([Xs[i//NUM_IMGS_PER_EXP]] * NUM_IMGS_PER_EXP)\n    l_score_exps[-1].append(lpips_scores[i])\n    c_score_exps[-1].append(similarity[i].item())\n    \nresult = dict()\nresult['l_score_exps'] = l_score_exps\nresult['c_score_exps'] = c_score_exps\nresult['x_axis'] = x_axis\n\npickle.dump(result, open(os.path.join(os.path.dirname(IMAGE_DIR), 'scores.pkl'), 'wb'))\n\n# import ipdb\n# ipdb.set_trace()\n# print([np.mean(l_score) for l_score in l_score_exps])\n# print([np.mean(c_score) for c_score in c_score_exps])\n\n    \n\n\n"
  },
  {
    "path": "scripts/stable_txt2img_guidance.py",
    "content": "import argparse, os, sys, glob\nimport torch\nimport numpy as np\nfrom omegaconf import OmegaConf\nfrom PIL import Image\nfrom tqdm import tqdm, trange\nfrom itertools import islice\nfrom einops import rearrange\nfrom torchvision.utils import make_grid, save_image\nimport time\nfrom pytorch_lightning import seed_everything\nfrom torch import autocast\nfrom contextlib import contextmanager, nullcontext\n\nfrom ldm.util import instantiate_from_config\nfrom ldm.models.diffusion.ddim import DDIMSampler\nfrom ldm.models.diffusion.plms import PLMSSampler\n\n\ndef chunk(it, size):\n    it = iter(it)\n    return iter(lambda: tuple(islice(it, size)), ())\n\n\ndef load_model_from_config(config, ckpt, verbose=False):\n    print(f\"Loading model from {ckpt}\")\n    pl_sd = torch.load(ckpt, map_location=\"cpu\")\n    if \"global_step\" in pl_sd:\n        print(f\"Global Step: {pl_sd['global_step']}\")\n    sd = pl_sd[\"state_dict\"]\n    model = instantiate_from_config(config.model)\n    m, u = model.load_state_dict(sd, strict=False)\n    if len(m) > 0 and verbose:\n        print(\"missing keys:\")\n        print(m)\n    if len(u) > 0 and verbose:\n        print(\"unexpected keys:\")\n        print(u)\n\n    model.cuda()\n    model.eval()\n    return model\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\n        \"--prompt\",\n        type=str,\n        nargs=\"?\",\n        default=\"a painting of a virus monster playing guitar\",\n        help=\"the prompt to render\"\n    )\n    parser.add_argument(\n        \"--outdir\",\n        type=str,\n        nargs=\"?\",\n        help=\"dir to write results to\",\n        default=\"outputs/txt2img-samples\"\n    )\n    parser.add_argument(\n        \"--skip_grid\",\n        action='store_true',\n        help=\"do not save a grid, only individual samples. Helpful when evaluating lots of samples\",\n    )\n    parser.add_argument(\n        \"--skip_save\",\n        action='store_true',\n        help=\"do not save individual samples. 
For speed measurements.\",\n    )\n    parser.add_argument(\n        \"--ddim_steps\",\n        type=int,\n        default=50,\n        help=\"number of ddim sampling steps\",\n    )\n    parser.add_argument(\n        \"--plms\",\n        action='store_true',\n        help=\"use plms sampling\",\n    )\n    parser.add_argument(\n        \"--laion400m\",\n        action='store_true',\n        help=\"uses the LAION400M model\",\n    )\n    parser.add_argument(\n        \"--fixed_code\",\n        action='store_true',\n        help=\"if enabled, uses the same starting code across samples \",\n    )\n    parser.add_argument(\n        \"--ddim_eta\",\n        type=float,\n        default=0.0,\n        help=\"ddim eta (eta=0.0 corresponds to deterministic sampling\",\n    )\n    parser.add_argument(\n        \"--n_iter\",\n        type=int,\n        default=2,\n        help=\"sample this often\",\n    )\n    parser.add_argument(\n        \"--H\",\n        type=int,\n        default=512,\n        help=\"image height, in pixel space\",\n    )\n    parser.add_argument(\n        \"--W\",\n        type=int,\n        default=512,\n        help=\"image width, in pixel space\",\n    )\n    parser.add_argument(\n        \"--C\",\n        type=int,\n        default=4,\n        help=\"latent channels\",\n    )\n    parser.add_argument(\n        \"--f\",\n        type=int,\n        default=8,\n        help=\"downsampling factor\",\n    )\n    parser.add_argument(\n        \"--n_samples\",\n        type=int,\n        default=3,\n        help=\"how many samples to produce for each given prompt. A.k.a. batch size\",\n    )\n    parser.add_argument(\n        \"--n_rows\",\n        type=int,\n        default=0,\n        help=\"rows in the grid (default: n_samples)\",\n    )\n    parser.add_argument(\n        \"--scale\",\n        type=float,\n        default=7.5,\n        help=\"unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))\",\n    )\n    parser.add_argument(\n        \"--from-file\",\n        type=str,\n        help=\"if specified, load prompts from this file\",\n    )\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        default=\"configs/stable-diffusion/v1-inference.yaml\",\n        help=\"path to config which constructs model\",\n    )\n    parser.add_argument(\n        \"--ckpt\",\n        type=str,\n        default=\"models/ldm/stable-diffusion-v4/sd-v1-4-full-ema.ckpt\",\n        help=\"path to checkpoint of model\",\n    )   \n    \n    parser.add_argument(\n        \"--sin_config\",\n        type=str,\n        default=\"configs/stable-diffusion/v1-inference.yaml\",\n    )\n    parser.add_argument(\n        \"--sin_ckpt\",\n        type=str,\n        required=True,\n    )\n     \n    parser.add_argument(\n        \"--seed\",\n        type=int,\n        default=42,\n        help=\"the seed (for reproducible sampling)\",\n    )\n    parser.add_argument(\n        \"--precision\",\n        type=str,\n        help=\"evaluate at this precision\",\n        choices=[\"full\", \"autocast\"],\n        default=\"autocast\"\n    )\n    parser.add_argument(\"--single_guidance\", action=\"store_true\")\n    parser.add_argument(\"--range_t_max\", type=int, default=400)\n    parser.add_argument(\"--range_t_min\", type=int, default=1)\n    parser.add_argument(\"--cond_beta\", type=float, default=0.5)\n\n    parser.add_argument(\n        \"--embedding_path\", \n        type=str, \n        help=\"Path to a pre-trained embedding manager checkpoint\")\n\n    
opt = parser.parse_args()\n\n    if opt.laion400m:\n        print(\"Falling back to LAION 400M model...\")\n        opt.config = \"configs/latent-diffusion/txt2img-1p4B-eval.yaml\"\n        opt.ckpt = \"models/ldm/text2img-large/model.ckpt\"\n        opt.outdir = \"outputs/txt2img-samples-laion400m\"\n\n    seed_everything(opt.seed)\n\n    config = OmegaConf.load(f\"{opt.config}\")\n    model = load_model_from_config(config, f\"{opt.ckpt}\")\n    opt.cond_beta_sin = 1. - opt.cond_beta\n    model.extra_config = vars(opt)\n    #model.embedding_manager.load(opt.embedding_path)\n\n    device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n    model = model.to(device)\n    \n    \n    sin_config = OmegaConf.load(f\"{opt.sin_config}\")\n    sin_model = load_model_from_config(sin_config, f\"{opt.sin_ckpt}\")\n    sin_model = sin_model.to(device)\n\n    if opt.plms:\n        sampler = PLMSSampler(model)\n    else:\n        if opt.single_guidance:\n            from ldm.models.diffusion.guidance_ddim import DDIMSinSampler\n            sampler = DDIMSinSampler(model, sin_model)\n        else:\n            sampler = DDIMSampler(model)\n\n    os.makedirs(opt.outdir, exist_ok=True)\n    outpath = opt.outdir\n\n    batch_size = opt.n_samples\n    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size\n    if not opt.from_file:\n        prompt = opt.prompt\n        assert prompt is not None\n        data = [batch_size * [prompt]]\n\n    else:\n        print(f\"reading prompts from {opt.from_file}\")\n        with open(opt.from_file, \"r\") as f:\n            data = f.read().splitlines()\n            data = list(chunk(data, batch_size))\n\n    sample_path = os.path.join(outpath, \"samples\")\n    os.makedirs(sample_path, exist_ok=True)\n    base_count = len(os.listdir(sample_path))\n    grid_count = len(os.listdir(outpath)) - 1\n\n    start_code = None\n    if opt.fixed_code:\n        start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)\n\n    precision_scope = autocast if opt.precision==\"autocast\" else nullcontext\n    with torch.no_grad():\n        with precision_scope(\"cuda\"):\n            with model.ema_scope():\n                with sin_model.ema_scope():\n                    tic = time.time()\n                    all_samples = list()\n                    for n in trange(opt.n_iter, desc=\"Sampling\"):\n                        for prompts in tqdm(data, desc=\"data\"):\n                            \n                            \n                            uc = None\n                            if opt.scale != 1.0:\n                                uc = model.get_learned_conditioning(batch_size * [\"\"])\n                                uc_sin = sin_model.get_learned_conditioning(batch_size * [\"\"])\n                            if isinstance(prompts, tuple):\n                                prompts = list(prompts)\n                                \n                            if opt.single_guidance:\n                                b = len(prompts)\n                                prompt = prompts[0]\n                                prompts = [prompt.split('[SEP]')[0].strip()] * b\n                                prompts_single = [prompt.split('[SEP]')[1].strip()] * b\n                            \n                            c = model.get_learned_conditioning(prompts)\n                            c_sin = sin_model.get_learned_conditioning(prompts_single)\n                            \n                            \n         
                   shape = [opt.C, opt.H // opt.f, opt.W // opt.f]\n                            samples_ddim, _ = sampler.sample(S=opt.ddim_steps,\n                                                            conditioning=c,\n                                                            conditioning_single=c_sin,\n                                                            batch_size=opt.n_samples,\n                                                            shape=shape,\n                                                            verbose=False,\n                                                            unconditional_guidance_scale=opt.scale,\n                                                            unconditional_conditioning=uc,\n                                                            unconditional_conditioning_single=uc_sin,\n                                                            eta=opt.ddim_eta,\n                                                            x_T=start_code)\n\n                            x_samples_ddim = model.decode_first_stage(samples_ddim)\n                            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n\n                            if not opt.skip_save:\n                                for x_sample in x_samples_ddim:\n                                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n                                    Image.fromarray(x_sample.astype(np.uint8)).save(\n                                        os.path.join(sample_path, f\"{base_count:05}.jpg\"))\n                                    base_count += 1\n\n                            if not opt.skip_grid:\n                                all_samples.append(x_samples_ddim)\n\n                    if not opt.skip_grid:\n                        # additionally, save as grid\n                        grid = torch.stack(all_samples, 0)\n                        grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n                        \n                        for i in range(grid.size(0)):\n                            save_image(grid[i, :, :, :], os.path.join(outpath,opt.prompt+'_{}.png'.format(i)))\n                        grid = make_grid(grid, nrow=n_rows)\n\n                        # to image\n                        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n                        Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(\" \", \"-\")}-{grid_count:04}.jpg'))\n                        grid_count += 1\n                    \n                    \n\n                toc = time.time()\n\n    print(f\"Your samples are ready and waiting for you here: \\n{outpath} \\n\"\n          f\" \\nEnjoy.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "scripts/stable_txt2img_multi_guidance.py",
    "content": "import argparse, os, sys, glob\nimport torch\nimport numpy as np\nfrom omegaconf import OmegaConf\nfrom PIL import Image\nfrom tqdm import tqdm, trange\nfrom itertools import islice\nfrom einops import rearrange\nfrom torchvision.utils import make_grid, save_image\nimport time\nfrom pytorch_lightning import seed_everything\nfrom torch import autocast\nfrom contextlib import contextmanager, nullcontext\n\nfrom ldm.util import instantiate_from_config\nfrom ldm.models.diffusion.ddim import DDIMSampler\nfrom ldm.models.diffusion.plms import PLMSSampler\n\n\ndef chunk(it, size):\n    it = iter(it)\n    return iter(lambda: tuple(islice(it, size)), ())\n\n\ndef load_model_from_config(config, ckpt, verbose=False):\n    print(f\"Loading model from {ckpt}\")\n    pl_sd = torch.load(ckpt, map_location=\"cpu\")\n    if \"global_step\" in pl_sd:\n        print(f\"Global Step: {pl_sd['global_step']}\")\n    sd = pl_sd[\"state_dict\"]\n    model = instantiate_from_config(config.model)\n    m, u = model.load_state_dict(sd, strict=False)\n    if len(m) > 0 and verbose:\n        print(\"missing keys:\")\n        print(m)\n    if len(u) > 0 and verbose:\n        print(\"unexpected keys:\")\n        print(u)\n\n    model.cuda()\n    model.eval()\n    return model\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument(\n        \"--prompt\",\n        type=str,\n        nargs=\"?\",\n        default=\"a painting of a virus monster playing guitar\",\n        help=\"the prompt to render\"\n    )\n    parser.add_argument(\n        \"--outdir\",\n        type=str,\n        nargs=\"?\",\n        help=\"dir to write results to\",\n        default=\"outputs/txt2img-samples\"\n    )\n    parser.add_argument(\n        \"--skip_grid\",\n        action='store_true',\n        help=\"do not save a grid, only individual samples. Helpful when evaluating lots of samples\",\n    )\n    parser.add_argument(\n        \"--skip_save\",\n        action='store_true',\n        help=\"do not save individual samples. 
For speed measurements.\",\n    )\n    parser.add_argument(\n        \"--ddim_steps\",\n        type=int,\n        default=50,\n        help=\"number of ddim sampling steps\",\n    )\n    parser.add_argument(\n        \"--plms\",\n        action='store_true',\n        help=\"use plms sampling\",\n    )\n    parser.add_argument(\n        \"--laion400m\",\n        action='store_true',\n        help=\"uses the LAION400M model\",\n    )\n    parser.add_argument(\n        \"--fixed_code\",\n        action='store_true',\n        help=\"if enabled, uses the same starting code across samples \",\n    )\n    parser.add_argument(\n        \"--ddim_eta\",\n        type=float,\n        default=0.0,\n        help=\"ddim eta (eta=0.0 corresponds to deterministic sampling\",\n    )\n    parser.add_argument(\n        \"--n_iter\",\n        type=int,\n        default=2,\n        help=\"sample this often\",\n    )\n    parser.add_argument(\n        \"--H\",\n        type=int,\n        default=512,\n        help=\"image height, in pixel space\",\n    )\n    parser.add_argument(\n        \"--W\",\n        type=int,\n        default=512,\n        help=\"image width, in pixel space\",\n    )\n    parser.add_argument(\n        \"--C\",\n        type=int,\n        default=4,\n        help=\"latent channels\",\n    )\n    parser.add_argument(\n        \"--f\",\n        type=int,\n        default=8,\n        help=\"downsampling factor\",\n    )\n    parser.add_argument(\n        \"--n_samples\",\n        type=int,\n        default=3,\n        help=\"how many samples to produce for each given prompt. A.k.a. batch size\",\n    )\n    parser.add_argument(\n        \"--n_rows\",\n        type=int,\n        default=0,\n        help=\"rows in the grid (default: n_samples)\",\n    )\n    parser.add_argument(\n        \"--scale\",\n        type=float,\n        default=7.5,\n        help=\"unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))\",\n    )\n    parser.add_argument(\n        \"--from-file\",\n        type=str,\n        help=\"if specified, load prompts from this file\",\n    )\n    parser.add_argument(\n        \"--config\",\n        type=str,\n        default=\"configs/stable-diffusion/v1-inference.yaml\",\n        help=\"path to config which constructs model\",\n    )\n    parser.add_argument(\n        \"--ckpt\",\n        type=str,\n        default=\"models/ldm/stable-diffusion-v4/sd-v1-4-full-ema.ckpt\",\n        help=\"path to checkpoint of model\",\n    )   \n    \n    parser.add_argument(\n        \"--sin_config\",\n        type=str,\n        default=\"configs/stable-diffusion/v1-inference.yaml\",\n        nargs='+'\n    )\n    parser.add_argument(\n        \"--sin_ckpt\",\n        type=str,\n        required=True,\n        nargs='+'\n    )\n     \n    parser.add_argument(\n        \"--seed\",\n        type=int,\n        default=42,\n        help=\"the seed (for reproducible sampling)\",\n    )\n    parser.add_argument(\n        \"--precision\",\n        type=str,\n        help=\"evaluate at this precision\",\n        choices=[\"full\", \"autocast\"],\n        default=\"autocast\"\n    )\n    parser.add_argument(\"--single_guidance\", action=\"store_true\")\n    parser.add_argument(\"--range_t_max\", type=int, default=400, nargs='+')\n    parser.add_argument(\"--range_t_min\", type=int, default=1, nargs='+')\n    parser.add_argument(\"--beta\", type=float, default=0.5, nargs='+')\n\n    parser.add_argument(\n        \"--embedding_path\", \n        type=str, \n        
help=\"Path to a pre-trained embedding manager checkpoint\")\n\n    opt = parser.parse_args()\n\n    if opt.laion400m:\n        print(\"Falling back to LAION 400M model...\")\n        opt.config = \"configs/latent-diffusion/txt2img-1p4B-eval.yaml\"\n        opt.ckpt = \"models/ldm/text2img-large/model.ckpt\"\n        opt.outdir = \"outputs/txt2img-samples-laion400m\"\n\n    seed_everything(opt.seed)\n    # import ipdb\n    # ipdb.set_trace()\n\n    config = OmegaConf.load(f\"{opt.config}\")\n    model = load_model_from_config(config, f\"{opt.ckpt}\")\n    #model.embedding_manager.load(opt.embedding_path)\n\n    device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n    model = model.to(device)\n    \n    guidance_configs = opt.sin_config\n    guidance_ckpts = opt.sin_ckpt\n    # import ipdb\n    # ipdb.set_trace()\n    assert len(guidance_configs) == len(guidance_ckpts)\n    \n    num_guidance = len(guidance_configs)\n    \n    guidance_betas = opt.beta if isinstance(opt.beta, list) else [opt.beta]*num_guidance\n    guidance_range_t_max = opt.range_t_max if isinstance(opt.range_t_max, list) else [opt.range_t_max]*num_guidance\n    guidance_range_t_min = opt.range_t_min if isinstance(opt.range_t_min, list) else [opt.range_t_min]*num_guidance\n    assert len(guidance_betas) == num_guidance and len(guidance_range_t_max) == num_guidance and len(guidance_range_t_min) == num_guidance\n    \n    guidance_models = [load_model_from_config(OmegaConf.load(config), ckpt).to(device) for config, ckpt in zip(guidance_configs, guidance_ckpts)]\n    for sin_model, beta, t_max, t_min in zip(guidance_models, guidance_betas, guidance_range_t_max, guidance_range_t_min):\n        sin_model.extra_config = {\"beta\": beta, \"range_t_max\": t_max, \"range_t_min\": t_min}\n    \n    from ldm.models.diffusion.guidance_ddim import DDIMMultiSampler\n    sampler = DDIMMultiSampler(model, guidance_models)\n\n    os.makedirs(opt.outdir, exist_ok=True)\n    outpath = opt.outdir\n\n    batch_size = opt.n_samples\n    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size\n    if not opt.from_file:\n        prompt = opt.prompt\n        assert prompt is not None\n        data = [batch_size * [prompt]]\n\n    else:\n        print(f\"reading prompts from {opt.from_file}\")\n        with open(opt.from_file, \"r\") as f:\n            data = f.read().splitlines()\n            data = list(chunk(data, batch_size))\n\n    sample_path = os.path.join(outpath, \"samples\")\n    os.makedirs(sample_path, exist_ok=True)\n    base_count = len(os.listdir(sample_path))\n    grid_count = len(os.listdir(outpath)) - 1\n\n    start_code = None\n    if opt.fixed_code:\n        start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)\n\n    precision_scope = autocast if opt.precision==\"autocast\" else nullcontext\n    with torch.no_grad():\n        with precision_scope(\"cuda\"):\n            with model.ema_scope():\n                for sin_model in guidance_models:\n                    sin_model.ema_scope()\n                # with guidance_models[0].ema_scope():\n                tic = time.time()\n                all_samples = list()\n                for n in trange(opt.n_iter, desc=\"Sampling\"):\n                    for prompts in tqdm(data, desc=\"data\"):\n                        # import ipdb\n                        # ipdb.set_trace()\n                        uc = None\n                        if opt.scale != 1.0:\n                            uc = 
model.get_learned_conditioning(batch_size * [\"\"])\n                            uc_sin_list = [sin_model.get_learned_conditioning(batch_size * [\"\"]) for sin_model in guidance_models]\n                        if isinstance(prompts, tuple):\n                            prompts = list(prompts)\n                            \n                        if opt.single_guidance:\n                            b = len(prompts)\n                            prompt = prompts[0]\n                            prompts = [prompt.split('[SEP]')[0].strip()] * b\n                            \n                            prompts_single = [[p.strip()] * b for p in prompt.split('[SEP]')[1:]]\n                            assert len(prompts_single) == num_guidance\n                        # import ipdb\n                        # ipdb.set_trace()\n                        c = model.get_learned_conditioning(prompts)\n                        c_sin_list = [sin_model.get_learned_conditioning(p) for sin_model, p in zip(guidance_models, prompts_single)]\n\n                        shape = [opt.C, opt.H // opt.f, opt.W // opt.f]\n                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,\n                                                        conditioning=c,\n                                                        conditioning_single_list=c_sin_list,\n                                                        batch_size=opt.n_samples,\n                                                        shape=shape,\n                                                        verbose=False,\n                                                        unconditional_guidance_scale=opt.scale,\n                                                        unconditional_conditioning=uc,\n                                                        unconditional_conditioning_single_list=uc_sin_list,\n                                                        eta=opt.ddim_eta,\n                                                        x_T=start_code)\n\n                        x_samples_ddim = model.decode_first_stage(samples_ddim)\n                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n\n                        if not opt.skip_save:\n                            for x_sample in x_samples_ddim:\n                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n                                Image.fromarray(x_sample.astype(np.uint8)).save(\n                                    os.path.join(sample_path, f\"{base_count:05}.jpg\"))\n                                base_count += 1\n\n                        if not opt.skip_grid:\n                            all_samples.append(x_samples_ddim)\n\n                if not opt.skip_grid:\n                    # additionally, save as grid\n                    grid = torch.stack(all_samples, 0)\n                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n                    \n                    for i in range(grid.size(0)):\n                        save_image(grid[i, :, :, :], os.path.join(outpath,opt.prompt+'_{}.png'.format(i)))\n                    grid = make_grid(grid, nrow=n_rows)\n\n                    # to image\n                    grid = 255. 
* rearrange(grid, 'c h w -> h w c').cpu().numpy()\n                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'{prompt.replace(\" \", \"-\")}-{grid_count:04}.jpg'))\n                    grid_count += 1\n                    \n                \n\n                toc = time.time()\n\n    print(f\"Your samples are ready and waiting for you here: \\n{outpath} \\n\"\n          f\" \\nEnjoy.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\nsetup(\n    name='latent-diffusion',\n    version='0.0.1',\n    description='',\n    packages=find_packages(),\n    install_requires=[\n        'torch',\n        'numpy',\n        'tqdm',\n    ],\n)"
  }
]