Repository: zhang-zx/SINE
Branch: master
Commit: c0df9502b622
Files: 51
Total size: 480.7 KB
Directory structure:
gitextract_l03wpjs2/
├── .gitignore
├── LICENSE
├── README.md
├── SINE.ipynb
├── configs/
│ └── stable-diffusion/
│ ├── v1-finetune_painting.yaml
│ ├── v1-finetune_painting_style.yaml
│ ├── v1-finetune_patch_painting.yaml
│ ├── v1-finetune_patch_picture.yaml
│ ├── v1-finetune_picture.yaml
│ ├── v1-inference.yaml
│ ├── v1-inference_patch.yaml
│ └── v1-inference_patch_nearest_interp.yaml
├── diffusers_models.py
├── diffusers_sample.py
├── diffusers_train.py
├── environment.yml
├── ldm/
│ ├── data/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── personalized.py
│ │ └── personalized_painting.py
│ ├── lr_scheduler.py
│ ├── modules/
│ │ ├── attention.py
│ │ ├── diffusionmodules/
│ │ │ ├── __init__.py
│ │ │ ├── model.py
│ │ │ ├── openaimodel.py
│ │ │ ├── positional_encoding.py
│ │ │ └── util.py
│ │ ├── distributions/
│ │ │ ├── __init__.py
│ │ │ └── distributions.py
│ │ ├── ema.py
│ │ ├── embedding_manager.py
│ │ ├── encoders/
│ │ │ ├── __init__.py
│ │ │ ├── modules.py
│ │ │ └── modules_bak.py
│ │ ├── image_degradation/
│ │ │ ├── __init__.py
│ │ │ ├── bsrgan.py
│ │ │ ├── bsrgan_light.py
│ │ │ └── utils_image.py
│ │ ├── losses/
│ │ │ ├── __init__.py
│ │ │ ├── contperceptual.py
│ │ │ └── vqperceptual.py
│ │ └── x_transformer.py
│ └── util.py
├── main.py
├── scripts/
│ ├── download_first_stages.sh
│ ├── download_models.sh
│ ├── sample_diffusion.py
│ ├── score.py
│ ├── stable_txt2img_guidance.py
│ └── stable_txt2img_multi_guidance.py
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# General
.DS_Store
.AppleDouble
.LSOverride
models
# models/
logs
logs/
exps/
datasets/
src/
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
*/logs/
*/wandb/
*samples/
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2022 Rinon Gal, Yuval Alaluf, Yuval Atzmon, Or Patashnik and contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
## SINE <br><sub> <ins>SIN</ins>gle Image <ins>E</ins>diting with Text-to-Image Diffusion Models</sub>
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/zhang-zx/SINE/blob/master/SINE.ipynb)
[Project](https://zhang-zx.github.io/SINE/) |
[ArXiv](https://arxiv.org/abs/2212.04489)
This repository contains the code for the CVPR 2023 paper [SINE: SINgle Image Editing with Text-to-Image Diffusion Models](https://arxiv.org/abs/2212.04489).
For more visualization results, please check our [webpage](https://zhang-zx.github.io/SINE/).
> **[SINE: SINgle Image Editing with Text-to-Image Diffusion Models](https://zhang-zx.github.io/SINE/)** \
> [Zhixing Zhang](https://zhang-zx.github.io/) <sup>1</sup>,
> [Ligong Han](https://phymhan.github.io/) <sup>1</sup>,
> [Arnab Ghosh](https://arnabgho.github.io/) <sup>2</sup>,
> [Dimitris Metaxas](https://people.cs.rutgers.edu/~dnm/) <sup>1</sup>,
> and [Jian Ren](https://alanspike.github.io/) <sup>2</sup> \
> <sup>1</sup> Rutgers University
> <sup>2</sup> Snap Inc.\
> CVPR 2023.
<div align="center">
<a><img src="assets/overview_finetuning.png" width="500" ></a>
<a><img src="assets/overview_editing.png" width="500" ></a>
</div>
## Setup
First, clone the repository:
```bash
git clone git@github.com:zhang-zx/SINE.git
```
Then, install the dependencies following the [instructions](https://github.com/CompVis/stable-diffusion#stable-diffusion-v1).
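For example, a minimal sketch using the `environment.yml` shipped with this repository (the environment name is an assumption taken from the upstream stable-diffusion setup; check `environment.yml` for the actual name):
```bash
cd SINE
# Create the conda environment from the provided environment.yml and install this package.
conda env create -f environment.yml
conda activate ldm   # assumed name; see environment.yml
pip install -e .
```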
Alternatively, you can use the following Docker image.
```bash
docker pull sunggukcha/sine
```
To fine-tune the model, you need to download the [pre-trained model](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4-full-ema.ckpt).
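For reference, the Colab notebook downloads the checkpoint to the path below; any location works as long as it matches the path passed to `--actual_resume` when fine-tuning:
```bash
# Download the Stable Diffusion v1.4 full-EMA checkpoint (large file).
mkdir -p models/ldm/stable-diffusion-v4
wget -O models/ldm/stable-diffusion-v4/sd-v1-4-full-ema.ckpt \
  https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4-full-ema.ckpt
```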
### Data Preparation
The data used in the paper can be found [here](https://drive.google.com/drive/folders/1rGt5YTCwNgEag8MD_1wr9jrPpi1_8vfu?usp=sharing).
## Fine-tuning
### Fine-tuning w/o patch-based training scheme
```bash
IMG_PATH=path/to/image
CLS_WRD='coarse class word'
NAME='name of the experiment'
python main.py \
--base configs/stable-diffusion/v1-finetune_picture.yaml \
-t --actual_resume /path/to/pre-trained/model \
-n $NAME --gpus 0, --logdir ./logs \
--data_root $IMG_PATH \
--reg_data_root $IMG_PATH --class_word $CLS_WRD
```
### Fine-tuning with patch-based training scheme
```bash
IMG_PATH=path/to/image
CLS_WRD='coarse class word'
NAME='name of the experiment'
python main.py \
--base configs/stable-diffusion/v1-finetune_patch_picture.yaml \
-t --actual_resume /path/to/pre-trained/model \
-n $NAME --gpus 0, --logdir ./logs \
--data_root $IMG_PATH \
--reg_data_root $IMG_PATH --class_word $CLS_WRD
```
## Model-based Image Editing
### Editing with one model's guidance
```bash
LOG_DIR=/path/to/logdir
python scripts/stable_txt2img_guidance.py --ddim_eta 0.0 --n_iter 1 \
--scale 10 --ddim_steps 100 \
--sin_config configs/stable-diffusion/v1-inference.yaml \
--sin_ckpt $LOG_DIR"/checkpoints/last.ckpt" \
--prompt "prompt for pre-trained model[SEP]prompt for fine-tuned model" \
--cond_beta 0.4 \
--range_t_min 500 --range_t_max 1000 --single_guidance \
--skip_save --H 512 --W 512 --n_samples 2 \
--outdir $LOG_DIR
```
### Editing with multiple models' guidance
```bash
python scripts/stable_txt2img_multi_guidance.py --ddim_eta 0.0 --n_iter 2 \
--scale 10 --ddim_steps 100 \
--sin_ckpt path/to/ckpt1 path/to/ckpt2 \
--sin_config ./configs/stable-diffusion/v1-inference.yaml \
configs/stable-diffusion/v1-inference.yaml \
--prompt "prompt for pre-trained model[SEP]prompt for fine-tuned model1[SEP]prompt for fine-tuned model2" \
--beta 0.4 0.5 \
--range_t_min 400 400 --range_t_max 1000 1000 --single_guidance \
--H 512 --W 512 --n_samples 2 \
--outdir path/to/output_dir
```
## Diffusers library Example
Support for the Diffusers library is still under development.
The results in our paper were obtained with the earlier LDM-based code described above.
### Training
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export IMG_PATH="path/to/image"
export OUTPUT_DIR="path/to/output_dir"
accelerate launch diffusers_train.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_text_encoder \
--img_path=$IMG_PATH \
--output_dir=$OUTPUT_DIR \
--instance_prompt="prompt for fine-tuning" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--learning_rate=1e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=NUMBERS_OF_STEPS \
--checkpointing_steps=FREQUENCY_FOR_CHECKPOINTING \
--patch_based_training # OPTIONAL: add this flag for patch-based training scheme
```
### Sampling
```bash
python diffusers_sample.py \
--pretrained_model_name_or_path "path/to/output_dir" \
--prompt "prompt for fine-tuned model" \
--editing_prompt 'prompt for pre-trained model'
```
## Visualization Results
Some of the editing results are shown below.
See more results on our [webpage](https://zhang-zx.github.io/SINE/).

## Acknowledgments
This code builds on the following implementations: [Dreambooth-Stable-Diffusion](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion) and [stable-diffusion](https://github.com/CompVis/stable-diffusion#stable-diffusion-v1).
The Diffusers-based implementation closely follows the [Dreambooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) example.
Many thanks to them!
## Reference
If our work or code helps you, please consider citing our paper. Thank you!
```BibTeX
@article{zhang2022sine,
title={SINE: SINgle Image Editing with Text-to-Image Diffusion Models},
author={Zhang, Zhixing and Han, Ligong and Ghosh, Arnab and Metaxas, Dimitris and Ren, Jian},
journal={arXiv preprint arXiv:2212.04489},
year={2022}
}
```
================================================
FILE: SINE.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Welcome to SINE: SINgle Image Editing with Text-to-Image Diffusion Models!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gpu_info = !nvidia-smi\n",
"gpu_info = '\\n'.join(gpu_info)\n",
"if gpu_info.find('failed') >= 0:\n",
" print('Not connected to a GPU')\n",
"else:\n",
" print(gpu_info)\n",
"\n",
"from psutil import virtual_memory\n",
"ram_gb = virtual_memory().total / 1e9\n",
"print('Your runtime has {:.1f} gigabytes of available RAM\\n'.format(ram_gb))\n",
"\n",
"if ram_gb < 20:\n",
" print('Not using a high-RAM runtime')\n",
"else:\n",
" print('You are using a high-RAM runtime!')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Step 1: Setup required libraries and models. \n",
"This may take a few minutes.\n",
"\n",
"You may optionally enable downloads with pydrive in order to authenticate and avoid drive download limits when fetching the pre-trained model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#@title Setup\n",
"\n",
"import os\n",
"\n",
"from pydrive.auth import GoogleAuth\n",
"from pydrive.drive import GoogleDrive\n",
"from google.colab import auth\n",
"from oauth2client.client import GoogleCredentials\n",
"\n",
"from argparse import Namespace\n",
"\n",
"import sys\n",
"import numpy as np\n",
"\n",
"from PIL import Image\n",
"\n",
"import torch\n",
"import torchvision.transforms as transforms\n",
"\n",
"device = 'cuda'\n",
"\n",
"\n",
"# install requirements\n",
"!git clone https://github.com/zhang-zx/SINE.git sine_dir\n",
"\n",
"%cd sine_dir/\n",
"!pip uninstall -y torchtext\n",
"! pip install transformers==4.18.0 einops==0.4.1 omegaconf==2.1.1 torchmetrics==0.6.0 torch-fidelity==0.3.0 kornia==0.6 albumentations==1.1.0 opencv-python==4.2.0.34 imageio==2.14.1 setuptools==59.5.0 pillow==9.0.1 \n",
"! pip install torch==1.10.2 torchvision==0.11.3\n",
"! pip install pytorch-lightning==1.5.9\n",
"! pip install git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers\n",
"! pip install git+https://github.com/openai/CLIP.git@main#egg=clip\n",
"! pip install -e .\n",
"\n",
"\n",
"\n",
"download_with_pydrive = True \n",
" \n",
"class Downloader(object):\n",
" def __init__(self, use_pydrive):\n",
" self.use_pydrive = use_pydrive\n",
"\n",
" if self.use_pydrive:\n",
" self.authenticate()\n",
" \n",
" def authenticate(self):\n",
" auth.authenticate_user()\n",
" gauth = GoogleAuth()\n",
" gauth.credentials = GoogleCredentials.get_application_default()\n",
" self.drive = GoogleDrive(gauth)\n",
" \n",
" def download_file(self, file_id, file_dst):\n",
" if self.use_pydrive:\n",
" downloaded = self.drive.CreateFile({'id':file_id})\n",
" downloaded.FetchMetadata(fetch_all=True)\n",
" downloaded.GetContentFile(file_dst)\n",
" else:\n",
" !gdown --id $file_id -O $file_dst\n",
"\n",
"downloader = Downloader(download_with_pydrive)\n",
"\n",
"pre_trained_path = os.path.join('models', 'ldm', 'stable-diffusion-v4')\n",
"os.makedirs(pre_trained_path, exist_ok=True)\n",
"!wget -O models/ldm/stable-diffusion-v4/sd-v1-4-full-ema.ckpt https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4-full-ema.ckpt \n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Step 2: Download the selected fine-tuned model. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"finetuned_models_dir = os.path.join('./models', 'finetuned')\n",
"os.makedirs(finetuned_models_dir, exist_ok=True)\n",
"\n",
"orig_image_dir = './dataset'\n",
"os.makedirs(orig_image_dir, exist_ok=True)\n",
"\n",
"source_model_type = 'dog w/o patch-based fine-tuning' #@param['dog w/o patch-based fine-tuning', 'dog w/ patch-based fine-tuning', 'Girl with a peral earring', 'Monalisa', 'castle w/o patch-based fine-tuning', 'castle w/ patch-based fine-tuning']\n",
"source_model_download_path = {\"dog w/o patch-based fine-tuning\": \"1jHgkyxrwUXyMR2zBK9WAEWioEP3-F3fd\",\n",
" \"dog w/ patch-based fine-tuning\": \"1YI7c29qBIy83OqJ4ykoAAul6P8uaXmls\",\n",
" \"Girl with a peral earring\": \"1l6GCEfyURKQiCF77ZriYoZRtkOXQCyWD\",\n",
" \"Monalisa\": \"194CDgHkomKrLvgFj89kamoTjUbMwyUIC\",\n",
" \"castle w/o patch-based fine-tuning\": \"19I8ftab9vMQWnqPH2O7aHe-GnYolmVFF\",\n",
" \"castle w/ patch-based fine-tuning\": \"1srzUr1fg6jTFKuf0M5oi5JgBsVhCt-nb\"}\n",
"\n",
"model_names = { \"dog w/o patch-based fine-tuning\": \"dog_wo_patch.ckpt\",\n",
" \"dog w/ patch-based fine-tuning\": \"dog_w_patch.ckpt\",\n",
" \"Girl with a peral earring\": \"girl.ckpt\",\n",
" \"Monalisa\": \"monalisa.ckpt\",\n",
" \"castle w/o patch-based fine-tuning\": \"castle_wo_patch\",\n",
" \"castle w/ patch-based fine-tuning\": \"castle_w_patch\"}\n",
"\n",
"model_configs = { \"dog w/o patch-based fine-tuning\": \"./configs/stable-diffusion/v1-inference.yaml\",\n",
" \"dog w/ patch-based fine-tuning\": \"./configs/stable-diffusion/v1-inference_patch.yaml\",\n",
" \"Girl with a peral earring\": \"./configs/stable-diffusion/v1-inference_patch_nearest.yaml\",\n",
" \"Monalisa\": \"./configs/stable-diffusion/v1-inference_patch_nearest.yaml\",\n",
" \"castle w/o patch-based fine-tuning\": \"./configs/stable-diffusion/v1-inference.yaml\",\n",
" \"castle w/ patch-based fine-tuning\": \"./configs/stable-diffusion/v1-inference_patch.yaml\"}\n",
"\n",
"orig_prompts = { \"dog w/o patch-based fine-tuning\": \"picture of a sks dog\",\n",
" \"dog w/ patch-based fine-tuning\": \"picture of a sks dog\",\n",
" \"Girl with a peral earring\": \"painting of a sks girl\",\n",
" \"Monalisa\": \"painting of a sks lady\",\n",
" \"castle w/o patch-based fine-tuning\": \"picture of a sks castle\",\n",
" \"castle w/ patch-based fine-tuning\": \"picture of a sks castle\"}\n",
"\n",
"download_string = source_model_download_path[source_model_type]\n",
"file_name = model_names[source_model_type]\n",
"\n",
"config_name = model_configs[source_model_type]\n",
"fine_tune_prompt = orig_prompts[source_model_type]\n",
"\n",
"if not os.path.isfile(os.path.join(finetuned_models_dir, file_name)):\n",
" downloader.download_file(download_string, os.path.join(finetuned_models_dir, file_name))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Step3: Edit the image with model-based guidance"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import argparse, os, sys, glob\n",
"import torch\n",
"import numpy as np\n",
"from omegaconf import OmegaConf\n",
"from PIL import Image\n",
"from tqdm import tqdm, trange\n",
"from itertools import islice\n",
"from einops import rearrange\n",
"from torchvision.utils import make_grid, save_image\n",
"import time\n",
"from pytorch_lightning import seed_everything\n",
"from torch import autocast\n",
"from contextlib import contextmanager, nullcontext\n",
"\n",
"from ldm.util import instantiate_from_config\n",
"from ldm.models.diffusion.ddim import DDIMSampler\n",
"from ldm.models.diffusion.plms import PLMSSampler\n",
"from IPython.display import display\n",
"\n",
"\n",
"def chunk(it, size):\n",
" it = iter(it)\n",
" return iter(lambda: tuple(islice(it, size)), ())\n",
"\n",
"\n",
"def load_model_from_config(config, ckpt, verbose=False):\n",
" print(f\"Loading model from {ckpt}\")\n",
" pl_sd = torch.load(ckpt, map_location=\"cpu\")\n",
" if \"global_step\" in pl_sd:\n",
" print(f\"Global Step: {pl_sd['global_step']}\")\n",
" sd = pl_sd[\"state_dict\"]\n",
" model = instantiate_from_config(config.model)\n",
" m, u = model.load_state_dict(sd, strict=False)\n",
" if len(m) > 0 and verbose:\n",
" print(\"missing keys:\")\n",
" print(m)\n",
" if len(u) > 0 and verbose:\n",
" print(\"unexpected keys:\")\n",
" print(u)\n",
"\n",
" model.cuda()\n",
" model.eval()\n",
" return model\n",
"\n",
"seed = 42\n",
"config = OmegaConf.load('configs/stable-diffusion/v1-inference.yaml')\n",
"model = load_model_from_config(config, 'models/ldm/stable-diffusion-v4/sd-v1-4-full-ema.ckpt')\n",
"\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"model = model.to(device)\n",
"\n",
"sin_config = OmegaConf.load(f\"{config_name}\")\n",
"sin_model = load_model_from_config(config, os.path.join(finetuned_models_dir, file_name))\n",
"sin_model = sin_model.to(device)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"v = 0.7 #@param {type:\"slider\", min:0, max:1, step:0.05}\n",
"K_min = 400 #@param {type:\"slider\", min:0, max:1000, step:10}\n",
"scale = 7.5 #@param {type:\"slider\", min:1.0, max:50, step:0.5}\n",
"ddim_steps = 100\n",
"ddim_eta = 0.\n",
"H = 512\n",
"W = 512\n",
"\n",
"prompt = \"a dog wearing a superhero cape\" #@param {'type': 'string'}\n",
"\n",
"extra_config = {\n",
" 'cond_beta': v,\n",
" 'cond_beta_sin': 1. - v,\n",
" 'range_t_max': 1000,\n",
" 'range_t_min': K_min\n",
"}\n",
"\n",
"\n",
"from ldm.models.diffusion.guidance_ddim import DDIMSinSampler\n",
"sampler = DDIMSinSampler(model, sin_model)\n",
"\n",
"setattr(sampler.model, 'extra_config', extra_config)\n",
"\n",
"\n",
"batch_size = 1\n",
"n_rows = 2\n",
"start_code = None\n",
"precision_scope = autocast\n",
"num_samples = 4\n",
"\n",
"all_samples = list()\n",
"\n",
"with torch.no_grad():\n",
" with precision_scope(\"cuda\"):\n",
" with model.ema_scope():\n",
" with sin_model.ema_scope():\n",
" tic = time.time()\n",
" all_samples = list()\n",
" for n in trange(num_samples, desc=\"Sampling\"): \n",
" uc = None\n",
" if scale != 1.0:\n",
" uc = model.get_learned_conditioning(batch_size * [\"\"])\n",
" uc_sin = sin_model.get_learned_conditioning(batch_size * [\"\"])\n",
"\n",
" prompts = [prompt] * batch_size\n",
" prompts_single = [fine_tune_prompt] * batch_size\n",
" \n",
" c = model.get_learned_conditioning(prompts)\n",
" c_sin = sin_model.get_learned_conditioning(prompts_single)\n",
" \n",
" shape = [4, H // 8, W // 8]\n",
" samples_ddim, _ = sampler.sample( S=ddim_steps,\n",
" conditioning=c,\n",
" conditioning_single=c_sin,\n",
" batch_size=batch_size,\n",
" shape=shape,\n",
" verbose=False,\n",
" unconditional_guidance_scale=scale,\n",
" unconditional_conditioning=uc,\n",
" unconditional_conditioning_single=uc_sin,\n",
" eta=ddim_eta,\n",
" x_T=start_code)\n",
"\n",
" x_samples_ddim = model.decode_first_stage(samples_ddim)\n",
" x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\n",
"\n",
" all_samples.append(x_samples_ddim)\n",
"\n",
" grid = torch.stack(all_samples, 0)\n",
" grid = rearrange(grid, 'n b c h w -> (n b) c h w')\n",
" \n",
" grid = make_grid(grid, nrow=n_rows)\n",
"\n",
" # to image\n",
" grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()\n",
" os.makedirs('./output', exist_ok=True)\n",
" Image.fromarray(grid.astype(np.uint8)).save(os.path.join('./output', f'{prompt.replace(\" \", \"-\")}.jpg'))\n",
" display(Image.open(os.path.join('./output', f'{prompt.replace(\" \", \"-\")}.jpg')))\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.0 ('SINE')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "a84f578b9cda1db545aa6690161d7775d6ea32a647f25bb9ef4866c136688289"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: configs/stable-diffusion/v1-finetune_painting.yaml
================================================
model:
base_learning_rate: 1.0e-06
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
reg_weight: 0.0
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 64
channels: 4
cond_stage_trainable: true # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
embedding_reg_weight: 0.0
unfreeze_model: True
model_lr: 1.0e-06
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 512
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: main.DataModuleFromConfig
params:
batch_size: 1
num_workers: 2
wrap: false
train:
target: ldm.data.personalized_painting.SinImageDataset
params:
size: 512
set: train
per_image_tokens: false
repeats: 100
flip_p: 0.0
reg:
target: ldm.data.personalized_painting.SinImageDataset
params:
size: 512
set: train
per_image_tokens: false
repeats: 100
flip_p: 0.0
validation:
target: ldm.data.personalized_painting.SinImageDataset
params:
size: 512
set: val
per_image_tokens: false
repeats: 10
lightning:
modelcheckpoint:
params:
every_n_train_steps: 500
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 500
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
max_steps: 800
================================================
FILE: configs/stable-diffusion/v1-finetune_painting_style.yaml
================================================
model:
base_learning_rate: 1.0e-06
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
reg_weight: 0.0
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 64
channels: 4
cond_stage_trainable: true # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
embedding_reg_weight: 0.0
unfreeze_model: True
model_lr: 1.0e-06
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 512
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: main.DataModuleFromConfig
params:
batch_size: 1
num_workers: 2
wrap: false
train:
target: ldm.data.personalized_painting.SinImageDataset
params:
size: 512
set: train
per_image_tokens: false
repeats: 100
flip_p: 0.0
learn_style: true
reg:
target: ldm.data.personalized_painting.SinImageDataset
params:
size: 512
set: train
per_image_tokens: false
repeats: 100
flip_p: 0.0
validation:
target: ldm.data.personalized_painting.SinImageDataset
params:
size: 512
set: val
per_image_tokens: false
repeats: 10
learn_style: true
lightning:
modelcheckpoint:
params:
every_n_train_steps: 500
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 500
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
max_steps: 800
================================================
FILE: configs/stable-diffusion/v1-finetune_patch_painting.yaml
================================================
model:
base_learning_rate: 1.0e-06
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
reg_weight: 0.0
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 64
channels: 4
cond_stage_trainable: true # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
embedding_reg_weight: 0.0
unfreeze_model: True
model_lr: 1.0e-6
scale_recon_loss: 1.0
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModelPatch
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
padding_idx: 0
init_size: 128
div_half_dim: false
center_shift: 100
interpolation_mode: "nearest"
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 512
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: main.DataModuleFromConfig
params:
batch_size: 1
num_workers: 2
wrap: false
train:
target: ldm.data.personalized_painting.SinImageHighResDataset
params:
size: 512
high_resolution: 1024
set: train
per_image_tokens: false
repeats: 100
min_crop_frac: 0.1
max_crop_frac: 1.0
rec_prob: 1.
latent_scale: 8
reg:
target: ldm.data.personalized_painting.SinImageHighResDataset
params:
size: 512
high_resolution: 1024
set: train
per_image_tokens: false
repeats: 1
min_crop_frac: 0.1
max_crop_frac: 1.0
rec_prob: 0.
latent_scale: 8
validation:
target: ldm.data.personalized_painting.SinImageHighResDataset
params:
size: 512
high_resolution: 1024
set: val
per_image_tokens: false
repeats: 10
min_crop_frac: 0.2
max_crop_frac: 1.0
rec_prob: 0.25
latent_scale: 8
lightning:
modelcheckpoint:
params:
every_n_train_steps: 2000
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 2000
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
max_steps: 10000
================================================
FILE: configs/stable-diffusion/v1-finetune_patch_picture.yaml
================================================
model:
base_learning_rate: 1.0e-06
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
reg_weight: 0.0
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 64
channels: 4
cond_stage_trainable: true # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
embedding_reg_weight: 0.0
unfreeze_model: True
model_lr: 1.0e-6
scale_recon_loss: 1.0
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModelPatch
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
padding_idx: 0
init_size: 128
div_half_dim: false
center_shift: 100
interpolation_mode: "bilinear"
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 512
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: main.DataModuleFromConfig
params:
batch_size: 1
num_workers: 2
wrap: false
train:
target: ldm.data.personalized.SinImageHighResDataset
params:
size: 512
high_resolution: 1024
set: train
per_image_tokens: false
repeats: 100
min_crop_frac: 0.1
max_crop_frac: 1.0
rec_prob: 1.
latent_scale: 8
reg:
target: ldm.data.personalized.SinImageHighResDataset
params:
size: 512
high_resolution: 1024
set: train
per_image_tokens: false
repeats: 1
min_crop_frac: 0.1
max_crop_frac: 1.0
rec_prob: 0.
latent_scale: 8
validation:
target: ldm.data.personalized.SinImageHighResDataset
params:
size: 512
high_resolution: 1024
set: val
per_image_tokens: false
repeats: 10
min_crop_frac: 0.2
max_crop_frac: 1.0
rec_prob: 0.25
latent_scale: 8
lightning:
modelcheckpoint:
params:
every_n_train_steps: 7000
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 2000
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
max_steps: 10000
================================================
FILE: configs/stable-diffusion/v1-finetune_picture.yaml
================================================
model:
base_learning_rate: 1.0e-06
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
reg_weight: 0.0
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 64
channels: 4
cond_stage_trainable: true # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
embedding_reg_weight: 0.0
unfreeze_model: True
model_lr: 1.0e-06
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 512
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: main.DataModuleFromConfig
params:
batch_size: 1
num_workers: 2
wrap: false
train:
target: ldm.data.personalized.SinImageDataset
params:
size: 512
set: train
per_image_tokens: false
repeats: 100
flip_p: 0.
reg:
target: ldm.data.personalized.SinImageDataset
params:
size: 512
set: train
per_image_tokens: false
repeats: 100
flip_p: 0.
validation:
target: ldm.data.personalized.SinImageDataset
params:
size: 512
set: val
per_image_tokens: false
repeats: 10
lightning:
modelcheckpoint:
params:
every_n_train_steps: 800
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 500
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
max_steps: 800
================================================
FILE: configs/stable-diffusion/v1-inference.yaml
================================================
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
================================================
FILE: configs/stable-diffusion/v1-inference_patch.yaml
================================================
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModelPatch
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
padding_idx: 0
init_size: 128 ## Note: Might be some problem in this line. Might need to be 1024
div_half_dim: false
center_shift: 100
interpolation_mode: "bilinear" # bilinear or nearest supported
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
================================================
FILE: configs/stable-diffusion/v1-inference_patch_nearest_interp.yaml
================================================
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModelPatch
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
padding_idx: 0
init_size: 128 ## Note: Might be some problem in this line. Might need to be 1024
div_half_dim: false
center_shift: 100
interpolation_mode: "nearest" # bilinear or nearest supported
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
================================================
FILE: diffusers_models.py
================================================
from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
from diffusers.configuration_utils import register_to_config
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import StrictInt, StrictFloat, StrictBool, StrictStr
import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from ldm.modules.diffusionmodules.positional_encoding import SinusoidalPositionalEmbedding
class UNet2DConditionPatchModel(UNet2DConditionModel):
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = (
"UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
time_embedding_type: str = "positional", # fourier, positional
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
padding_idx: StrictInt = 0,
init_size: StrictInt = 128,
div_half_dim: StrictBool = False,
center_shift: StrictInt = 64,
interpolation_mode: StrictStr = "bilinear",
):
super().__init__(sample_size=sample_size,
in_channels=in_channels,
out_channels=out_channels,
center_input_sample=center_input_sample,
flip_sin_to_cos=flip_sin_to_cos,
freq_shift=freq_shift,
down_block_types=down_block_types,
mid_block_type=mid_block_type,
up_block_types=up_block_types,
only_cross_attention=only_cross_attention,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
downsample_padding=downsample_padding,
mid_block_scale_factor=mid_block_scale_factor,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
class_embed_type=class_embed_type,
num_class_embeds=num_class_embeds,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
time_embedding_type=time_embedding_type, # fourier, positional
timestep_post_act=timestep_post_act,
time_cond_proj_dim=time_cond_proj_dim,
conv_in_kernel=conv_in_kernel,
conv_out_kernel=conv_out_kernel,
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
)
assert block_out_channels[0] % 2 == 0
self.head_position_encode = SinusoidalPositionalEmbedding(embedding_dim=block_out_channels[0]//2,
padding_idx=padding_idx,
init_size=init_size,
div_half_dim=div_half_dim,
center_shift=center_shift)
self.init_size = init_size
self.interpolation_mode = interpolation_mode
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
crop_boxes: Optional[torch.Tensor] = None,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
# By default, samples have to be at least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
forward_upsample_size = True
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor(
[timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError(
"class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# 2. pre-process
sample = self.conv_in(sample)
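# Patch-based positional encoding: build a sinusoidal grid over the full-image
# coordinate frame (init_size x init_size), crop it to each sample's crop box when
# crop_boxes are given (otherwise use the whole grid), resize it to the feature map
# size and add it to the conv_in features so the UNet knows where the patch sits
# within the original image.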
head_grid = self.head_position_encode(torch.ones([sample.shape[0], sample.shape[1], self.init_size, self.init_size], dtype=self.dtype,
device=sample.device))
if crop_boxes is not None:
head_grid = torch.cat([F.interpolate(hg.unsqueeze(0)[:, :, box[0]: box[2], box[1]: box[3]],
(sample.shape[2], sample.shape[3]), mode='bilinear', align_corners=True)
for hg, box in
zip(head_grid, crop_boxes)], dim=0)
else:
head_grid = F.interpolate(
head_grid, (sample.shape[2], sample.shape[3]), mode='bilinear', align_corners=True)
sample += head_grid
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(
hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[: -len(
upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
# 6. post-process
if self.conv_norm_out:
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return UNet2DConditionOutput(sample=sample)
if __name__ == "__main__":
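# Minimal smoke test: instantiate the patch-aware UNet from Stable Diffusion v1.4 `unet` weights.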
unet = UNet2DConditionPatchModel.from_pretrained(
"CompVis/stable-diffusion-v1-4", subfolder="unet", revision=None, low_cpu_mem_usage=False, device_map=None
)
================================================
FILE: diffusers_sample.py
================================================
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
import torch
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from transformers import AutoTokenizer, PretrainedConfig
from typing import Any, Callable, Dict, List, Optional, Union
import importlib
import os
import diffusers
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
import argparse
from accelerate.utils import ProjectConfiguration, set_seed
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=revision,
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
else:
raise ValueError(f"{model_class} is not supported.")
class StableDiffusionGuidancePipeline(StableDiffusionPipeline):
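# Frozen copies of the pre-trained text encoder and UNet used for model-based guidance;
# they are attached after construction via add_pretrained_model().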
text_encoder_orig = None
unet_orig = None
def __init__(
self,
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker,
feature_extractor,
requires_safety_checker,
):
super().__init__(vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker,
feature_extractor,
requires_safety_checker,)
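# Record the actual UNet class (e.g. UNet2DConditionPatchModel) in the pipeline config
# so that from_pretrained() resolves the same class when the saved pipeline is reloaded.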
self.config['unet'] = (unet.__module__, unet.config['_class_name'])
def add_pretrained_model(self, text_encoder, unet):
self.text_encoder_orig = text_encoder.to(self._execution_device)
self.unet_orig = unet.to(self._execution_device)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
torch_dtype = kwargs.pop("torch_dtype", None)
provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None)
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
return_cached_folder = kwargs.pop("return_cached_folder", False)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
cached_folder = pretrained_model_name_or_path
config_dict = cls.load_config(cached_folder)
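# Saved pipelines may not record the module defining the custom UNet; fall back to this
# repo's diffusers_models module so UNet2DConditionPatchModel can be imported.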
if config_dict['unet'][0] is None:
config_dict['unet'][0] = 'diffusers_models'
# 2. Load the pipeline class
pipeline_class = cls
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
# define init kwargs
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
# remove `null` components
def load_module(name, value):
if isinstance(value, bool):
return False
if value[0] is None:
return False
if name in passed_class_obj and passed_class_obj[name] is None:
return False
return True
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
# import it here to avoid circular import
from diffusers import pipelines
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if class_name.startswith("Flax"):
class_name = class_name[4:]
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
# pass
# set passed class object
loaded_sub_model = passed_class_obj[name]
elif is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
else:
# else we just import it from the library.
# NOTE: here I reuse library_name as the module name
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
if loaded_sub_model is None:
load_method_name = 'from_pretrained'
load_method = getattr(class_obj, load_method_name)
loading_kwargs = {}
if issubclass(class_obj, torch.nn.Module):
loading_kwargs["torch_dtype"] = torch_dtype
# if issubclass(class_obj, diffusers.OnnxRuntimeModel):
# loading_kwargs["provider"] = provider
# loading_kwargs["sess_options"] = sess_options
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
# This makes sure that the weights won't be initialized which significantly speeds up loading.
if is_diffusers_model:
loading_kwargs["device_map"] = device_map
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
else:
# else load from the root directory
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
init_kwargs['requires_safety_checker'] = False
# 4. Potentially add passed objects if expected
missing_modules = set(expected_modules) - set(init_kwargs.keys())
passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
for module in missing_modules:
init_kwargs[module] = passed_class_obj.get(module, None)
elif len(missing_modules) > 0:
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
raise ValueError(
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)
# 5. Instantiate the pipeline
model = pipeline_class(**init_kwargs)
if return_cached_folder:
return model, cached_folder
return model
def _encode_prompt_orig(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
if hasattr(self.text_encoder_orig.config, "use_attention_mask") and self.text_encoder_orig.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
prompt_embeds = self.text_encoder_orig(
text_input_ids.to(device),
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_orig.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder_orig.config, "use_attention_mask") and self.text_encoder_orig.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
negative_prompt_embeds = self.text_encoder_orig(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_orig.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return prompt_embeds
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_edit: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
model_based_guidance_scale: float = 0.0,
K_min: int = 400,
):
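# Sketch of the extra arguments this pipeline adds on top of the standard text-to-image call:
#   prompt_edit: prompt fed to the frozen, pre-trained model (encoded via _encode_prompt_orig).
#   model_based_guidance_scale: blend weight v in [0, 1] mixing the fine-tuned model's
#       text-conditioned prediction into the frozen model's prediction.
#   K_min: timestep threshold; the blend is only applied while t > K_min.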
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)
self.check_inputs(
prompt_edit, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
assert isinstance(prompt_edit, str)
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
assert isinstance(prompt_edit, list)
assert len(prompt_edit) == len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
prompt_embeds_edit = self._encode_prompt_orig(
prompt_edit,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=None,
negative_prompt_embeds=None,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet_orig(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds_edit,
cross_attention_kwargs=cross_attention_kwargs,
).sample
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
if t > K_min and model_based_guidance_scale > 0.0:
noise_pred_guidance = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
).sample
if do_classifier_free_guidance:
noise_pred_guidance_uncond, noise_pred_guidance_text = noise_pred_guidance.chunk(2)
noise_pred_text = noise_pred_text * (1 - model_based_guidance_scale) + noise_pred_guidance_text * model_based_guidance_scale
else:
noise_pred = noise_pred * (1 - model_based_guidance_scale) + noise_pred_guidance * model_based_guidance_scale
# perform guidance
if do_classifier_free_guidance:
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
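# Sketch of the guidance recipe implemented above (with classifier-free guidance enabled):
# with v = model_based_guidance_scale and w = guidance_scale, while t > K_min
#   eps_text  = (1 - v) * eps_orig_text + v * eps_finetuned_text
#   eps_final = eps_uncond + w * (eps_text - eps_uncond)
# where eps_orig_* comes from self.unet_orig (editing prompt) and eps_finetuned_*
# from self.unet (fine-tuning prompt).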
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if output_type == "latent":
image = latents
elif output_type == "pil":
# 8. Post-processing
image = self.decode_latents(latents)
# 10. Convert to PIL
image = self.numpy_to_pil(image)
else:
# 8. Post-processing
image = self.decode_latents(latents)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (image, None)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
def parse_args(input_args=None):
parser = argparse.ArgumentParser(
description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
'--prompt',
type=str,
default=None,
required=True,
help='Text prompt for the fine-tuned model.'
)
parser.add_argument(
'--editing_prompt',
type=str,
default=None,
required=True,
help='Text prompt for the pre-trained model.'
)
parser.add_argument("--seed", type=int, default=412441,
help="A seed for reproducible sampling.")
parser.add_argument('--num_images_per_prompt', type=int, default=2, help='Number of images to generate per prompt.')
parser.add_argument('--num_iterations', type=int, default=1, help='Number of sampling rounds to run.')
parser.add_argument('--model_based_guidance_scale', type=float, default=0.3, help='Scale of model-based guidance.')
parser.add_argument('--guidance_scale', type=float, default=7.5, help='Scale of classifier-free guidance.')
parser.add_argument('--K', default=400, type=int, help='Timestep threshold (K_min) below which model-based guidance is no longer applied.')
parser.add_argument('--ddim_steps', default=100, type=int, help='Number of DDIM sampling steps.')
parser.add_argument('--height', default=512, type=int, help='Height of the image')
parser.add_argument('--width', default=512, type=int, help='Width of the image')
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
return args
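# Illustrative invocation (paths and prompts are placeholders, not shipped defaults):
#   python diffusers_sample.py \
#       --pretrained_model_name_or_path /path/to/finetuned-pipeline \
#       --prompt "a photo of a sks house" \
#       --editing_prompt "a photo of a house covered by snow" \
#       --model_based_guidance_scale 0.3 --guidance_scale 7.5 \
#       --K 400 --ddim_steps 100 --num_images_per_prompt 2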
if __name__ == "__main__":
args = parse_args()
set_seed(args.seed)
pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"
text_encoder_cls = import_model_class_from_model_name_or_path(pretrained_model_name_or_path, None)
text_encoder = text_encoder_cls.from_pretrained(
pretrained_model_name_or_path, subfolder="text_encoder", revision=None
)
unet = UNet2DConditionModel.from_pretrained(
pretrained_model_name_or_path, subfolder="unet", revision=None
)
model_id = args.pretrained_model_name_or_path
pipe = StableDiffusionGuidancePipeline.from_pretrained(model_id, torch_dtype=torch.float).to("cuda")
pipe.add_pretrained_model(text_encoder=text_encoder, unet=unet)
prompt = args.prompt
prompt_edit = args.editing_prompt
file_name = (prompt + '[SEP]' + prompt_edit).replace(' ', '_')
for i in range(args.num_iterations):
images = pipe(prompt=prompt, prompt_edit=prompt_edit,
model_based_guidance_scale=args.model_based_guidance_scale,
K_min=args.K, num_inference_steps=args.ddim_steps,
guidance_scale=args.guidance_scale,
num_images_per_prompt=args.num_images_per_prompt, height=args.height, width=args.width).images
for j, image in enumerate(images):
image.save("{}/{}_{}.png".format(args.pretrained_model_name_or_path, file_name, i * args.num_images_per_prompt + j))
================================================
FILE: diffusers_train.py
================================================
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import hashlib
import itertools
import logging
import math
import os
import warnings
from pathlib import Path
from typing import Optional
import random
from copy import deepcopy
import accelerate
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from packaging import version
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
import numpy as np
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from diffusers_models import UNet2DConditionPatchModel
logger = get_logger(__name__)
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=revision,
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
else:
raise ValueError(f"{model_class} is not supported.")
def parse_args(input_args=None):
parser = argparse.ArgumentParser(
description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help=(
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
" float32 precision."
),
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--img_path",
type=str,
default=None,
required=True,
)
parser.add_argument(
"--instance_prompt",
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance",
)
parser.add_argument(
"--output_dir",
type=str,
default="text-inversion-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=412441,
help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--train_text_encoder",
action="store_true",
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
"instructions."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=(
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
" for more details"
),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-6,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--lr_num_cycles",
type=int,
default=1,
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
)
parser.add_argument("--lr_power", type=float, default=1.0,
help="Power factor of the polynomial scheduler.")
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument("--adam_beta1", type=float, default=0.9,
help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999,
help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float,
default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08,
help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0,
type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true",
help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None,
help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--prior_generation_precision",
type=str,
default=None,
choices=["no", "fp32", "fp16", "bf16"],
help=(
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
),
)
parser.add_argument("--local_rank", type=int, default=-1,
help="For distributed training: local_rank")
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--set_grads_to_none",
action="store_true",
help=(
"Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
" behaviors, so disable this argument if it causes any problems. More info:"
" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
),
)
parser.add_argument(
"--patch_based_training",
action="store_true",
help=(
"activate patch based training"
),
)
parser.add_argument(
'--high_res',
type=int,
default=1024,
help='the highest resolution provided to the dataloader'
)
parser.add_argument(
'--latent_scale',
type=int,
default=8,
help='the downsampling factor between image space and the latent space'
)
parser.add_argument(
'--min_crop_frac',
type=float,
default=0.1,
help='the minimum fraction of the image to crop'
)
parser.add_argument(
'--max_crop_frac',
type=float,
default=1,
help='the maximum fraction of the image to crop'
)
parser.add_argument(
'--rec_prob',
type=float,
default=0.1,
help='the probability of using the whole image as the crop'
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
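# Illustrative invocation (paths and the prompt are placeholders, not shipped defaults):
#   accelerate launch diffusers_train.py \
#       --pretrained_model_name_or_path CompVis/stable-diffusion-v1-4 \
#       --img_path /path/to/single_image.jpg \
#       --instance_prompt "a photo of a sks house" \
#       --output_dir exps/house --resolution 512 --train_batch_size 1 \
#       --max_train_steps 800 --patch_based_training --high_res 1024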
class SINEDatasetPatch(Dataset):
def __init__(
self,
img_path,
instance_prompt,
tokenizer,
size=512,
high_res=1024,
min_crop_frac=0.1,
max_crop_frac=1,
rec_prob=0.1,
latent_scale=8,
):
super().__init__()
self.size = size
self.tokenizer = tokenizer
self.img_path = Path(img_path)
if not self.img_path.exists():
raise ValueError(f"Image {self.img_path} doesn't exists.")
self.num_instance_images = 1
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
self.image = Image.open(img_path)
if not self.image.mode == "RGB":
self.image = self.image.convert("RGB")
self.image = self.image.resize((high_res, high_res), resample=Image.Resampling.BICUBIC)
self.instance_prompt_ids = self.tokenizer(
self.instance_prompt,
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids
self.high_res = high_res
self.min_crop_frac = min_crop_frac
self.max_crop_frac = max_crop_frac
self.rec_prob = rec_prob
self.latent_scale = latent_scale
self.image_transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length * 20
def _random_crop(self, pil_image):
patch_size_y = int(
(self.high_res//self.latent_scale) * (random.random() * (self.max_crop_frac - self.min_crop_frac) + self.min_crop_frac))
patch_size_x = int(
(self.high_res//self.latent_scale) * (random.random() * (self.max_crop_frac - self.min_crop_frac) + self.min_crop_frac))
crop_y = random.randrange((self.high_res//self.latent_scale) - patch_size_y + 1)
crop_x = random.randrange((self.high_res//self.latent_scale) - patch_size_x + 1)
return pil_image.crop((crop_x * self.latent_scale, crop_y * self.latent_scale, (crop_x + patch_size_x) * self.latent_scale, (crop_y + patch_size_y) * self.latent_scale)).resize(
(self.size, self.size),
resample=Image.Resampling.BICUBIC), crop_y, crop_x, crop_y + patch_size_y, crop_x + patch_size_x
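# Note on the crop arithmetic above: crop sizes and offsets are sampled on the latent
# grid (high_res // latent_scale cells) and multiplied back by latent_scale, so every
# pixel-space crop box aligns with whole latent cells. E.g. with high_res=1024 and
# latent_scale=8 the grid is 128x128; a sampled box (y0=10, x0=20, y1=74, x1=84) maps
# to the pixel box (left=160, upper=80, right=672, lower=592) before being resized to
# `size` x `size`.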
def __getitem__(self, i):
example = {}
image = deepcopy(self.image)
if random.random() > self.rec_prob:
image, crop_y, crop_x, crop_y1, crop_x1 = self._random_crop(image)
crop_area = torch.tensor([crop_y, crop_x, crop_y1, crop_x1])
else:
image = image.resize((self.size, self.size), resample=Image.Resampling.BICUBIC )
crop_area = torch.tensor([0, 0, self.high_res//self.latent_scale, self.high_res//self.latent_scale])
example = {
"instance_images": self.image_transform(image),
"instance_prompt_ids": self.instance_prompt_ids,
"crop_area": crop_area,
}
return example
class SINEDatasetSingleRes(Dataset):
def __init__(
self,
img_path,
instance_prompt,
tokenizer,
size=512,
):
super().__init__()
self.size = size
self.tokenizer = tokenizer
self.img_path = Path(img_path)
if not self.img_path.exists():
raise ValueError(f"Image {self.img_path} doesn't exists.")
self.num_instance_images = 1
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
self.image = Image.open(img_path)
if not self.image.mode == "RGB":
self.image = self.image.convert("RGB")
self.image_transform = transforms.Compose(
[
transforms.Resize(
(size, size), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
self.image = self.image_transform(self.image)
self.instance_prompt_ids = self.tokenizer(
self.instance_prompt,
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids
def __len__(self):
return self._length * 20
def __getitem__(self, index):
example = {
"instance_images": self.image,
"instance_prompt_ids": self.instance_prompt_ids,
}
return example
def collate_fn(examples):
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
# Stack the instance images and concatenate the tokenized prompts into a single batch.
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = torch.cat(input_ids, dim=0)
if "crop_area" in examples[0]:
crop_area = [example["crop_area"] for example in examples]
crop_area = torch.stack(crop_area)
# crop_area = crop_area.to(memory_format=torch.contiguous_format).float()
batch = {
"input_ids": input_ids,
"pixel_values": pixel_values,
"crop_area": crop_area,
}
return batch
batch = {
"input_ids": input_ids,
"pixel_values": pixel_values,
}
return batch
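# Batch layout (sketch): input_ids -> (B, tokenizer.model_max_length),
# pixel_values -> (B, 3, size, size), and, only for the patch-based dataset,
# crop_area -> (B, 4) latent-grid boxes (y0, x0, y1, x1).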
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(
total_limit=args.checkpoints_total_limit)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
logging_dir=logging_dir,
project_config=accelerator_project_config,
)
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
raise ValueError(
"Gradient accumulation is not supported when training the text encoder in distributed training. "
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
create_repo(repo_name, exist_ok=True, token=args.hub_token)
repo = Repository(
args.output_dir, clone_from=repo_name, token=args.hub_token)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name, revision=args.revision, use_fast=False)
elif args.pretrained_model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(
args.pretrained_model_name_or_path, args.revision)
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
if args.patch_based_training:
unet = UNet2DConditionPatchModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False, device_map=None
)
else:
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
for model in models:
sub_dir = "unet" if type(model) == type(
unet) else "text_encoder"
model.save_pretrained(os.path.join(output_dir, sub_dir))
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
def load_model_hook(models, input_dir):
while len(models) > 0:
# pop models so that they are not loaded again
model = models.pop()
if type(model) == type(text_encoder):
# load transformers style into model
load_model = text_encoder_cls.from_pretrained(
input_dir, subfolder="text_encoder")
model.config = load_model.config
else:
# load diffusers style into model
load_model = UNet2DConditionModel.from_pretrained(
input_dir, subfolder="unet")
model.register_to_config(**load_model.config)
model.load_state_dict(load_model.state_dict())
del load_model
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
vae.requires_grad_(False)
if not args.train_text_encoder:
text_encoder.requires_grad_(False)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError(
"xformers is not available. Make sure it is installed correctly")
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.train_text_encoder:
text_encoder.gradient_checkpointing_enable()
# Check that all trainable models are in full precision
low_precision_error_string = (
"Please make sure to always have all model weights in full float32 precision when starting training - even if"
" doing mixed precision training. copy of the weights should still be float32."
)
if accelerator.unwrap_model(unet).dtype != torch.float32:
raise ValueError(
f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
)
if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError(
f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
f" {low_precision_error_string}"
)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps *
args.train_batch_size * accelerator.num_processes
)
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
# Optimizer creation
params_to_optimize = (
itertools.chain(unet.parameters(), text_encoder.parameters(
)) if args.train_text_encoder else unet.parameters()
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# Dataset and DataLoaders creation:
train_dataset = SINEDatasetSingleRes(
img_path=args.img_path,
instance_prompt=args.instance_prompt,
tokenizer=tokenizer,
size=args.resolution,
) if not args.patch_based_training else SINEDatasetPatch(
img_path=args.img_path,
instance_prompt=args.instance_prompt,
tokenizer=tokenizer,
size=args.resolution,
high_res=args.high_res,
min_crop_frac=args.min_crop_frac,
max_crop_frac=args.max_crop_frac,
rec_prob=args.rec_prob,
latent_scale=args.latent_scale,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples),
num_workers=args.dataloader_num_workers,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move vae and text_encoder to device and cast to weight_dtype
vae.to(accelerator.device, dtype=weight_dtype)
if not args.train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(
args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initialize automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth", config=vars(args))
# Train!
total_batch_size = args.train_batch_size * \
accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(
f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(
f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (
num_update_steps_per_epoch * args.gradient_accumulation_steps)
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps),
disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(
dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(
latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if not args.patch_based_training else \
unet(noisy_latents, timesteps, encoder_hidden_states, crop_boxes=batch['crop_area']).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(
latents, noise, timesteps)
else:
raise ValueError(
f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred.float(),
target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
itertools.chain(unet.parameters(),
text_encoder.parameters())
if args.train_text_encoder
else unet.parameters()
)
accelerator.clip_grad_norm_(
params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
save_path = os.path.join(
args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.detach().item(
), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
revision=args.revision,
)
pipeline.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training",
blocking=False, auto_lfs_prune=True)
accelerator.end_training()
if __name__ == "__main__":
args = parse_args()
main(args)
================================================
FILE: environment.yml
================================================
name: SINE
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- ca-certificates=2022.10.11=h06a4308_0
- certifi=2022.9.24=py38h06a4308_0
- libedit=3.1.20210910=h7f8727e_0
- libffi=3.2.1=hf484d3e_1007
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- ncurses=6.3=h5eee18b_3
- openssl=1.1.1s=h7f8727e_0
- pip=22.2.2=py38h06a4308_0
- python=3.8.0=h0371630_2
- readline=7.0=h7b6447c_5
- sqlite=3.33.0=h62c20be_0
- tk=8.6.12=h1ccaba5_0
- wheel=0.37.1=pyhd3eb1b0_0
- xz=5.2.8=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- pip:
- absl-py==1.3.0
- accelerate==0.16.0
- aiohttp==3.8.3
- aiosignal==1.3.1
- albumentations==1.1.0
- antlr4-python3-runtime==4.8
- asttokens==2.2.1
- async-timeout==4.0.2
- attrs==22.1.0
- autopep8==2.0.0
- backcall==0.2.0
- cachetools==5.2.0
- charset-normalizer==2.1.1
- debugpy==1.6.4
- decorator==5.1.1
- einops==0.6.0
- entrypoints==0.4
- executing==1.2.0
- filelock==3.8.2
- frozenlist==1.3.3
- fsspec==2022.11.0
- ftfy==6.1.1
- future==0.18.2
- google-auth==2.15.0
- google-auth-oauthlib==0.4.6
- grpcio==1.51.1
- huggingface-hub==0.11.1
- idna==3.4
- imageio==2.14.1
- imageio-ffmpeg==0.4.7
- importlib-metadata==5.1.0
- ipdb==0.13.9
- ipykernel==6.17.1
- ipython==8.7.0
- jedi==0.18.2
- jinja2==3.1.2
- joblib==1.2.0
- jupyter-client==7.4.8
- jupyter-core==5.1.0
- kornia==0.6.0
- markdown==3.4.1
- markupsafe==2.1.1
- matplotlib-inline==0.1.6
- multidict==6.0.3
- nest-asyncio==1.5.6
- networkx==2.8.8
- numpy==1.23.5
- oauthlib==3.2.2
- omegaconf==2.1.1
- opencv-python==4.2.0.34
- opencv-python-headless==4.6.0.66
- packaging==21.3
- pandas==1.5.2
- parso==0.8.3
- pexpect==4.8.0
- pickleshare==0.7.5
- pillow==9.3.0
- platformdirs==2.6.0
- prompt-toolkit==3.0.36
- protobuf==3.20.3
- psutil==5.9.4
- ptyprocess==0.7.0
- pudb==2019.2
- pure-eval==0.2.2
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pycodestyle==2.10.0
- pydantic==1.10.2
- pydeprecate==0.3.1
- pygments==2.13.0
- pyparsing==3.0.9
- python-dateutil==2.8.2
- pytorch-lightning==1.5.9
- pytz==2022.6
- pywavelets==1.4.1
- pyyaml==6.0
- pyzmq==24.0.1
- qudida==0.0.4
- regex==2022.10.31
- requests==2.28.1
- requests-oauthlib==1.3.1
- rsa==4.9
- scikit-image==0.19.3
- scikit-learn==1.1.3
- scipy==1.9.3
- setuptools==59.5.0
- six==1.16.0
- stack-data==0.6.2
- tensorboard==2.11.0
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
- test-tube==0.7.5
- threadpoolctl==3.1.0
- tifffile==2022.10.10
- tokenizers==0.13.2
- toml==0.10.2
- tomli==2.0.1
- torch==1.11.0+cu113
- torchmetrics==0.11.0
- torchvision==0.12.0+cu113
- tornado==6.2
- tqdm==4.64.1
- traitlets==5.6.0
- transformers==4.25.1
- typing-extensions==4.5.0
- urllib3==1.26.13
- urwid==2.1.2
- wcwidth==0.2.5
- werkzeug==2.2.2
- yarl==1.8.2
- zipp==3.11.0
================================================
FILE: ldm/data/__init__.py
================================================
================================================
FILE: ldm/data/base.py
================================================
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
class Txt2ImgIterableBaseDataset(IterableDataset):
'''
Define an interface to make the IterableDatasets for text2img data chainable
'''
def __init__(self, num_records=0, valid_ids=None, size=256):
super().__init__()
self.num_records = num_records
self.valid_ids = valid_ids
self.sample_ids = valid_ids
self.size = size
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
def __len__(self):
return self.num_records
@abstractmethod
def __iter__(self):
pass
================================================
FILE: ldm/data/personalized.py
================================================
import os
import numpy as np
import re
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from copy import deepcopy
import random
import torch
training_templates_smallest = [
'photo of a sks {}',
]
reg_templates_smallest = [
'photo of a {}',
]
imagenet_templates_small = [
'a photo of a {}',
'a rendering of a {}',
'a cropped photo of the {}',
'the photo of a {}',
'a photo of a clean {}',
'a photo of a dirty {}',
'a dark photo of the {}',
'a photo of my {}',
'a photo of the cool {}',
'a close-up photo of a {}',
'a bright photo of the {}',
'a cropped photo of a {}',
'a photo of the {}',
'a good photo of the {}',
'a photo of one {}',
'a close-up photo of the {}',
'a rendition of the {}',
'a photo of the clean {}',
'a rendition of a {}',
'a photo of a nice {}',
'a good photo of a {}',
'a photo of the nice {}',
'a photo of the small {}',
'a photo of the weird {}',
'a photo of the large {}',
'a photo of a cool {}',
'a photo of a small {}',
'an illustration of a {}',
'a rendering of a {}',
'a cropped photo of the {}',
'the photo of a {}',
'an illustration of a clean {}',
'an illustration of a dirty {}',
'a dark photo of the {}',
'an illustration of my {}',
'an illustration of the cool {}',
'a close-up photo of a {}',
'a bright photo of the {}',
'a cropped photo of a {}',
'an illustration of the {}',
'a good photo of the {}',
'an illustration of one {}',
'a close-up photo of the {}',
'a rendition of the {}',
'an illustration of the clean {}',
'a rendition of a {}',
'an illustration of a nice {}',
'a good photo of a {}',
'an illustration of the nice {}',
'an illustration of the small {}',
'an illustration of the weird {}',
'an illustration of the large {}',
'an illustration of a cool {}',
'an illustration of a small {}',
'a depiction of a {}',
'a rendering of a {}',
'a cropped photo of the {}',
'the photo of a {}',
'a depiction of a clean {}',
'a depiction of a dirty {}',
'a dark photo of the {}',
'a depiction of my {}',
'a depiction of the cool {}',
'a close-up photo of a {}',
'a bright photo of the {}',
'a cropped photo of a {}',
'a depiction of the {}',
'a good photo of the {}',
'a depiction of one {}',
'a close-up photo of the {}',
'a rendition of the {}',
'a depiction of the clean {}',
'a rendition of a {}',
'a depiction of a nice {}',
'a good photo of a {}',
'a depiction of the nice {}',
'a depiction of the small {}',
'a depiction of the weird {}',
'a depiction of the large {}',
'a depiction of a cool {}',
'a depiction of a small {}',
]
imagenet_dual_templates_small = [
'a photo of a {} with {}',
'a rendering of a {} with {}',
'a cropped photo of the {} with {}',
'the photo of a {} with {}',
'a photo of a clean {} with {}',
'a photo of a dirty {} with {}',
'a dark photo of the {} with {}',
'a photo of my {} with {}',
'a photo of the cool {} with {}',
'a close-up photo of a {} with {}',
'a bright photo of the {} with {}',
'a cropped photo of a {} with {}',
'a photo of the {} with {}',
'a good photo of the {} with {}',
'a photo of one {} with {}',
'a close-up photo of the {} with {}',
'a rendition of the {} with {}',
'a photo of the clean {} with {}',
'a rendition of a {} with {}',
'a photo of a nice {} with {}',
'a good photo of a {} with {}',
'a photo of the nice {} with {}',
'a photo of the small {} with {}',
'a photo of the weird {} with {}',
'a photo of the large {} with {}',
'a photo of a cool {} with {}',
'a photo of a small {} with {}',
]
reg_templates_smallest = [
'photo of a {}',
]
reg_templates_no_class_smallest = [
'a photo',
]
reg_templates_no_class_small = [
'a photo',
'a rendering',
'a cropped photo',
'the photo',
'a dark photo',
'a close-up photo',
'a bright photo',
'a cropped photo',
'a good photo',
'a rendition',
'an illustration',
'a depiction',
]
per_img_token_list = [
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
]
class PersonalizedBase(Dataset):
def __init__(self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="dog",
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
coarse_class_text=None,
reg = False
):
self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
self.placeholder_token = placeholder_token
self.per_image_tokens = per_image_tokens
self.center_crop = center_crop
self.mixing_prob = mixing_prob
self.coarse_class_text = coarse_class_text
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} images. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.reg = reg
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
placeholder_string = self.placeholder_token
if self.coarse_class_text:
placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
if not self.reg:
text = random.choice(training_templates_smallest).format(placeholder_string)
else:
text = random.choice(reg_templates_smallest).format(placeholder_string)
example["caption"] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
def crop_image(img, size=512, cropping='random', crop_scale=[1, 1]):
if cropping in ['random', 'random_long_edge']:
h, w, = img.shape[0], img.shape[1]
crop = min(h, w)
crop = int(torch.empty(1).uniform_(crop_scale[0], crop_scale[1]).item() * crop)
offset_h = np.random.randint(0, h - crop + 1)
offset_w = np.random.randint(0, w - crop + 1)
img = img[offset_h:offset_h + crop, offset_w:offset_w + crop]
elif cropping in ['center', 'center_long_edge']:
h, w, = img.shape[0], img.shape[1]
crop = min(h, w)
crop = int(torch.empty(1).uniform_(crop_scale[0], crop_scale[1]).item() * crop)
img = img[(h - crop) // 2:(h + crop) // 2, (w - crop) // 2:(w + crop) // 2]
elif cropping == 'crop':
h, w, = img.shape[0], img.shape[1]
crop = min(h, w, size)
crop = int(torch.empty(1).uniform_(crop_scale[0], crop_scale[1]).item() * crop)
offset_h = np.random.randint(0, h - crop + 1)
offset_w = np.random.randint(0, w - crop + 1)
img = img[offset_h:offset_h + crop, offset_w:offset_w + crop]
else:
raise NotImplementedError
return img
class PersonalizedMulti(Dataset):
def __init__(
self,
data_root="",
size=None,
repeats=100,
interpolation="lanczos",
flip_p=0.5,
which_set="train",
placeholder_token="sks",
per_image_tokens=False,
cropping='random_long_edge',
crop_scale=[1, 1],
mixing_prob=0.25,
coarse_class_text=None,
reg=False,
use_small_template=False,
delimiters = ",|:|;",
**kwargs,
):
# NOTE: split str to list
data_root = re.split(delimiters, data_root)
placeholder_token = re.split(delimiters, placeholder_token)
# import ipdb
# ipdb.set_trace()
if coarse_class_text:
coarse_class_text = re.split(delimiters, coarse_class_text)
coarse_class_text = [(None if (s in ['none', 'None', 'null', 'Null']) else s) for s in coarse_class_text]
else:
coarse_class_text = [None] * len(data_root)
assert len(placeholder_token) == len(data_root) == len(coarse_class_text)
self.keys = placeholder_token
self.data_root = {k: v for k, v in zip(self.keys, data_root)}
self.placeholder_token = {k: v for k, v in zip(self.keys, placeholder_token)}
self.coarse_class_text = {k: v for k, v in zip(self.keys, coarse_class_text)}
self.image_paths = {k: [os.path.join(self.data_root[k], file_path) for file_path in os.listdir(self.data_root[k])] if os.path.isdir(self.data_root[k]) else [self.data_root[k]] for k in self.keys}
self.num_images = {k: len(self.image_paths[k]) for k in self.keys}
self._length = max([self.num_images[k] for k in self.keys])
if which_set == "train":
self._length = self._length * repeats
self.reg = reg
# self.per_image_tokens = per_image_tokens
self.cropping = cropping
self.crop_scale = crop_scale
self.mixing_prob = mixing_prob
self.use_small_template = use_small_template
self.templates = {
k: self.setup_templates(
placeholder_token=self.placeholder_token[k],
coarse_class_text=self.coarse_class_text[k],
reg=reg, use_small_template=use_small_template,
) for k in self.keys
}
self.size = size
self.interpolation = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def setup_templates(self, placeholder_token='sks', coarse_class_text='dog', reg=False, use_small_template=False):
if reg: # NOTE: reg dataset
if coarse_class_text:
placeholder_string = f"{coarse_class_text}"
templates = imagenet_templates_small if use_small_template else reg_templates_smallest
templates = [t.format(placeholder_string) for t in templates]
else:
templates = reg_templates_no_class_small if use_small_template else reg_templates_no_class_smallest
else: # NOTE: train dataset
if coarse_class_text:
placeholder_string = f"{placeholder_token} {coarse_class_text}"
else:
placeholder_string = f"{placeholder_token}"
templates = imagenet_templates_small if use_small_template else reg_templates_smallest
templates = [t.format(placeholder_string) for t in templates]
return templates
def __len__(self):
return self._length
def __getitem__(self, i):
key = random.choice(self.keys)
example = {}
image = Image.open(self.image_paths[key][i % self.num_images[key]])
if not image.mode == "RGB":
image = image.convert("RGB")
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
img = crop_image(img, size=self.size, cropping=self.cropping, crop_scale=self.crop_scale)
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
text = random.choice(self.templates[key])
example["caption"] = text.rstrip()
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
example["label"] = [self.keys.index(key)]
# import ipdb
# ipdb.set_trace()
return example
class SinImageDataset(PersonalizedBase):
def __init__(
self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="dog",
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
coarse_class_text=None,
reg = False
):
self.data_root = data_root
assert os.path.isfile(self.data_root), f"SinImageDataset requires a path to a image file, not a directory. Got {self.data_root}."
self.image_paths = [self.data_root]*100
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
self.placeholder_token = placeholder_token
self.per_image_tokens = per_image_tokens
self.center_crop = center_crop
self.mixing_prob = mixing_prob
self.coarse_class_text = coarse_class_text
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.reg = reg
class SinImageHighResDataset(Dataset):
def __init__(self,
data_root,
size=512,
high_resolution=1024,
latent_scale=8,
min_crop_frac=0.5,
max_crop_frac=1.0,
rec_prob=0.0,
repeats=100,
interpolation="bicubic",
flip_p=0.,
set="train",
placeholder_token="dog",
per_image_tokens=False,
mixing_prob=0.25,
coarse_class_text=None):
self.data_root = data_root
        assert os.path.isfile(self.data_root), f"SinImageDataset requires a path to an image file, not a directory. Got {self.data_root}."
self.num_images = 100
self._length = self.num_images
self.placeholder_token = placeholder_token
self.per_image_tokens = per_image_tokens
self.mixing_prob = mixing_prob
self.coarse_class_text = coarse_class_text
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.high_resolution = high_resolution
self.min_crop_frac = min_crop_frac
self.max_crop_frac = max_crop_frac
self.rec_prob = rec_prob
self.latent_scale = latent_scale
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
image = Image.open(self.data_root)
if not image.mode == "RGB":
image = image.convert("RGB")
self.image = image.resize((self.high_resolution, self.high_resolution), resample=self.interpolation)
def __len__(self):
return self._length
def _random_crop(self, pil_image):
patch_size_y = int(
(self.high_resolution//self.latent_scale) * (random.random() * (self.max_crop_frac - self.min_crop_frac) + self.min_crop_frac))
patch_size_x = int(
(self.high_resolution//self.latent_scale) * (random.random() * (self.max_crop_frac - self.min_crop_frac) + self.min_crop_frac))
crop_y = random.randrange((self.high_resolution//self.latent_scale) - patch_size_y + 1)
crop_x = random.randrange((self.high_resolution//self.latent_scale) - patch_size_x + 1)
return pil_image.crop((crop_x * self.latent_scale, crop_y * self.latent_scale, (crop_x + patch_size_x) * self.latent_scale, (crop_y + patch_size_y) * self.latent_scale)).resize(
(self.size, self.size),
resample=self.interpolation), crop_y, crop_x, crop_y + patch_size_y, crop_x + patch_size_x
def __getitem__(self, i):
example = {}
image = deepcopy(self.image)
placeholder_string = self.placeholder_token
if self.coarse_class_text:
placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
text = random.choice(training_templates_smallest).format(placeholder_string)
example["caption"] = text
if random.random() < self.rec_prob:
image, crop_y, crop_x, crop_y1, crop_x1 = self._random_crop(image)
crop_area = torch.tensor([crop_y, crop_x, crop_y1, crop_x1])
else:
image = image.resize((self.size, self.size), resample=self.interpolation)
crop_area = torch.tensor([0, 0, self.high_resolution//self.latent_scale, self.high_resolution//self.latent_scale])
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
example['crop_boxes'] = crop_area
return example
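# ----------------------------------------------------------------------------
# Illustrative usage sketch (added, not part of the original file): how the
# delimiter-separated arguments of PersonalizedMulti line up. The folder
# names and placeholder tokens below are hypothetical.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    # __init__ splits these plain strings on "," / ":" / ";" (the default
    # `delimiters`), so every concept gets its own data_root, token and
    # optional coarse class.
    print(re.split(",|:|;", "datasets/cat;datasets/dog"))  # ['datasets/cat', 'datasets/dog']
    print(re.split(",|:|;", "sks;uy"))                     # ['sks', 'uy']
    # With existing folders, the dataset could then be built as:
    #     PersonalizedMulti(data_root="datasets/cat;datasets/dog",
    #                       placeholder_token="sks;uy",
    #                       coarse_class_text="cat;dog",
    #                       size=512, cropping="random_long_edge")
    # and every __getitem__ call picks one concept at random, returning
    # "image" scaled to [-1, 1], a template-based "caption" and the concept "label".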
================================================
FILE: ldm/data/personalized_painting.py
================================================
import os
import numpy as np
import re
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from copy import deepcopy
import random
import torch
training_templates_smallest = [
'painting of a sks {}',
]
style_templates_smallest = [
'{} in the sks style',
]
reg_templates_smallest = [
'painting of a {}',
]
# NOTE: the assignment below overrides the painting template above,
# so regularization captions use "photo of a {}".
reg_templates_smallest = [
    'photo of a {}',
]
reg_templates_no_class_smallest = [
'a photo',
]
reg_templates_no_class_small = [
'a photo',
'a rendering',
'a cropped photo',
'the photo',
'a dark photo',
'a close-up photo',
'a bright photo',
'a cropped photo',
'a good photo',
'a rendition',
'an illustration',
'a depiction',
]
per_img_token_list = [
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
]
class PersonalizedBase(Dataset):
def __init__(self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="dog",
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
coarse_class_text=None,
reg=False,
learn_style=False,
):
self.data_root = data_root
self.image_paths = [os.path.join(
self.data_root, file_path) for file_path in os.listdir(self.data_root)]
self.num_images = len(self.image_paths)
self._length = self.num_images
self.placeholder_token = placeholder_token
self.per_image_tokens = per_image_tokens
self.center_crop = center_crop
self.mixing_prob = mixing_prob
self.coarse_class_text = coarse_class_text
self.learn_style = learn_style
if per_image_tokens:
assert self.num_images < len(
per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.reg = reg
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
placeholder_string = self.placeholder_token
if self.coarse_class_text:
placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
if not self.reg:
if not self.learn_style:
text = random.choice(training_templates_smallest).format(
placeholder_string)
else:
text = random.choice(style_templates_smallest).format(
placeholder_string)
else:
text = random.choice(reg_templates_smallest).format(
placeholder_string)
example["caption"] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
            h, w = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size),
resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
class SinImageDataset(PersonalizedBase):
def __init__(
self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="dog",
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
coarse_class_text=None,
reg=False,
learn_style=False,
):
self.data_root = data_root
assert os.path.isfile(
self.data_root), f"SinImageDataset requires a path to a image file, not a directory. Got {self.data_root}."
self.image_paths = [self.data_root]*100
self.num_images = len(self.image_paths)
self._length = self.num_images
self.placeholder_token = placeholder_token
self.per_image_tokens = per_image_tokens
self.center_crop = center_crop
self.mixing_prob = mixing_prob
self.coarse_class_text = coarse_class_text
self.learn_style = learn_style
if per_image_tokens:
assert self.num_images < len(
per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.reg = reg
class SinImageHighResDataset(Dataset):
def __init__(self,
data_root,
size=512,
high_resolution=1024,
latent_scale=8,
min_crop_frac=0.5,
max_crop_frac=1.0,
rec_prob=0.0,
repeats=100,
interpolation="bicubic",
flip_p=0.,
set="train",
placeholder_token="dog",
per_image_tokens=False,
mixing_prob=0.25,
coarse_class_text=None):
self.data_root = data_root
assert os.path.isfile(
self.data_root), f"SinImageDataset requires a path to a image file, not a directory. Got {self.data_root}."
self.num_images = 100
self._length = self.num_images
self.placeholder_token = placeholder_token
self.per_image_tokens = per_image_tokens
self.mixing_prob = mixing_prob
self.coarse_class_text = coarse_class_text
if per_image_tokens:
assert self.num_images < len(
per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.high_resolution = high_resolution
self.min_crop_frac = min_crop_frac
self.max_crop_frac = max_crop_frac
self.rec_prob = rec_prob
self.latent_scale = latent_scale
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
image = Image.open(self.data_root)
if not image.mode == "RGB":
image = image.convert("RGB")
self.image = image.resize(
(self.high_resolution, self.high_resolution), resample=self.interpolation)
def __len__(self):
return self._length
def _random_crop(self, pil_image):
patch_size_y = int(
(self.high_resolution//self.latent_scale) * (random.random() * (self.max_crop_frac - self.min_crop_frac) + self.min_crop_frac))
patch_size_x = int(
(self.high_resolution//self.latent_scale) * (random.random() * (self.max_crop_frac - self.min_crop_frac) + self.min_crop_frac))
crop_y = random.randrange(
(self.high_resolution//self.latent_scale) - patch_size_y + 1)
crop_x = random.randrange(
(self.high_resolution//self.latent_scale) - patch_size_x + 1)
return pil_image.crop((crop_x * self.latent_scale, crop_y * self.latent_scale, (crop_x + patch_size_x) * self.latent_scale, (crop_y + patch_size_y) * self.latent_scale)).resize(
(self.size, self.size),
resample=self.interpolation), crop_y, crop_x, crop_y + patch_size_y, crop_x + patch_size_x
def __getitem__(self, i):
example = {}
image = deepcopy(self.image)
placeholder_string = self.placeholder_token
if self.coarse_class_text:
placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
text = random.choice(training_templates_smallest).format(
placeholder_string)
example["caption"] = text
if random.random() < self.rec_prob:
image, crop_y, crop_x, crop_y1, crop_x1 = self._random_crop(image)
crop_area = torch.tensor([crop_y, crop_x, crop_y1, crop_x1])
else:
image = image.resize((self.size, self.size),
resample=self.interpolation)
crop_area = torch.tensor(
[0, 0, self.high_resolution//self.latent_scale, self.high_resolution//self.latent_scale])
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
example['crop_boxes'] = crop_area
return example
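# ----------------------------------------------------------------------------
# Illustrative sketch (added, not part of the original file): the crop
# arithmetic used by SinImageHighResDataset._random_crop, with concrete
# numbers. The resolutions and offsets below are assumptions for the
# walk-through only.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    high_resolution, latent_scale = 1024, 8
    latent_res = high_resolution // latent_scale  # 128 latent cells per side
    # Patch sizes and offsets are drawn in latent units, so the pixel-space
    # crop box is always a multiple of latent_scale and stays aligned with
    # the autoencoder's downsampling grid.
    patch_size = int(latent_res * 0.75)            # e.g. 96 latent cells
    crop_y = crop_x = 10                           # latent-unit offsets
    pixel_box = (crop_x * latent_scale, crop_y * latent_scale,
                 (crop_x + patch_size) * latent_scale,
                 (crop_y + patch_size) * latent_scale)
    print(pixel_box)                               # (80, 80, 848, 848)
    assert all(v % latent_scale == 0 for v in pixel_box)
    # __getitem__ stores the latent-unit box [crop_y, crop_x,
    # crop_y + patch_size, crop_x + patch_size] under example['crop_boxes'].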
================================================
FILE: ldm/lr_scheduler.py
================================================
import numpy as np
class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps:
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
1 + np.cos(t * np.pi))
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n,**kwargs)
class LambdaWarmUpCosineScheduler2:
"""
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
interval = 0
for cl in self.cum_cycles[1:]:
if n <= cl:
return interval
interval += 1
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi))
self.last_f = f
return f
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
self.last_f = f
return f
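# ----------------------------------------------------------------------------
# Illustrative sketch (added, not part of the original file): these schedules
# return a learning-rate *multiplier* (hence the "use with a base_lr of 1.0"
# notes above), so one common way to use them is through
# torch.optim.lr_scheduler.LambdaLR, which multiplies a real base_lr by the
# value the schedule returns. The model, optimizer and step counts below are
# assumptions.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    lr_lambda = LambdaLinearScheduler(
        warm_up_steps=[100], f_min=[1.0], f_max=[10.0], f_start=[1e-6],
        cycle_lengths=[10000],
    )
    sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)
    for step in range(3):
        opt.step()
        sched.step()  # effective lr = 1e-4 * lr_lambda(step)
        print(step, sched.get_last_lr())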
================================================
FILE: ldm/modules/attention.py
================================================
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from ldm.modules.diffusionmodules.util import checkpoint
def exists(val):
return val is not None
def uniq(arr):
    return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x+h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
)
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c')
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
return x + x_in
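# ----------------------------------------------------------------------------
# Illustrative sketch (added, not part of the original file): the shape flow
# described in the SpatialTransformer docstring -- image-like features are
# flattened to a token sequence, cross-attended against a text context, then
# reshaped back to a feature map. The channel, head and context sizes below
# are assumptions.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    b, c, h, w = 2, 320, 16, 16
    context_dim, context_len = 768, 77            # e.g. CLIP token embeddings
    block = SpatialTransformer(in_channels=c, n_heads=8, d_head=40,
                               depth=1, context_dim=context_dim)
    x = torch.randn(b, c, h, w)
    context = torch.randn(b, context_len, context_dim)
    out = block(x, context=context)
    assert out.shape == (b, c, h, w)              # residual output, same shape
    print(out.shape)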
================================================
FILE: ldm/modules/diffusionmodules/__init__.py
================================================
================================================
FILE: ldm/modules/diffusionmodules/model.py
================================================
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from ldm.util import instantiate_from_config
from ldm.modules.attention import LinearAttention
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0))
return emb
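# Worked example (added comment, not in the original file): for
# embedding_dim = 8 and a single timestep t = 5, half_dim = 4 and the
# frequencies are w_i = exp(-i * log(10000) / 3) for i = 0..3, so the result
# is [sin(5*w_0), ..., sin(5*w_3), cos(5*w_0), ..., cos(5*w_3)];
# e.g. get_timestep_embedding(torch.tensor([5]), 8) has shape (1, 8).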
def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=0)
def forward(self, x):
if self.with_conv:
pad = (0,1,0,1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x+h
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b,c,h*w)
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b,c,h,w)
h_ = self.proj_out(h_)
return x+h_
def make_attn(in_channels, attn_type="vanilla"):
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
return AttnBlock(in_channels)
elif attn_type == "none":
return nn.Identity(in_channels)
else:
return LinAttnBlock(in_channels)
class Model(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = self.ch*4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.use_timestep = use_timestep
if self.use_timestep:
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
torch.nn.Linear(self.ch,
self.temb_ch),
torch.nn.Linear(self.temb_ch,
self.temb_ch),
])
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
padding=1)
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
skip_in = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
if i_block == self.num_res_blocks:
skip_in = ch*in_ch_mult[i_level]
block.append(ResnetBlock(in_channels=block_in+skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x, t=None, context=None):
#assert x.shape[2] == x.shape[3] == self.resolution
if context is not None:
# assume aligned context, cat along channel axis
x = torch.cat((x, context), dim=1)
if self.use_timestep:
# timestep embedding
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](
torch.cat([h, hs.pop()], dim=1), temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.weight
class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
**ignore_kwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
padding=1)
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
attn_type="vanilla", **ignorekwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,)+tuple(ch_mult)
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,
padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, z):
#assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
if self.tanh_out:
h = torch.tanh(h)
return h
class SimpleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, *args, **kwargs):
super().__init__()
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
ResnetBlock(in_channels=in_channels,
out_channels=2 * in_channels,
temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=2 * in_channels,
out_channels=4 * in_channels,
temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=4 * in_channels,
out_channels=2 * in_channels,
temb_channels=0, dropout=0.0),
nn.Conv2d(2*in_channels, in_channels, 1),
Upsample(in_channels, with_conv=True)])
# end
self.norm_out = Normalize(in_channels)
self.conv_out = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
for i, layer in enumerate(self.model):
if i in [1,2,3]:
x = layer(x, None)
else:
x = layer(x)
h = self.norm_out(x)
h = nonlinearity(h)
x = self.conv_out(h)
return x
class UpsampleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
ch_mult=(2,2), dropout=0.0):
super().__init__()
# upsampling
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = in_channels
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.res_blocks = nn.ModuleList()
self.upsample_blocks = nn.ModuleList()
for i_level in range(self.num_resolutions):
res_block = []
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
res_block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
self.res_blocks.append(nn.ModuleList(res_block))
if i_level != self.num_resolutions - 1:
self.upsample_blocks.append(Upsample(block_in, True))
curr_res = curr_res * 2
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
# upsampling
h = x
for k, i_level in enumerate(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.res_blocks[i_level][i_block](h, None)
if i_level != self.num_resolutions - 1:
h = self.upsample_blocks[k](h)
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class LatentRescaler(nn.Module):
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
super().__init__()
# residual block, interpolate, residual block
self.factor = factor
self.conv_in = nn.Conv2d(in_channels,
mid_channels,
kernel_size=3,
stride=1,
padding=1)
self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0) for _ in range(depth)])
self.attn = AttnBlock(mid_channels)
self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0) for _ in range(depth)])
self.conv_out = nn.Conv2d(mid_channels,
out_channels,
kernel_size=1,
)
def forward(self, x):
x = self.conv_in(x)
for block in self.res_block1:
x = block(x, None)
x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
x = self.attn(x)
for block in self.res_block2:
x = block(x, None)
x = self.conv_out(x)
return x
class MergedRescaleEncoder(nn.Module):
def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True,
ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
super().__init__()
intermediate_chn = ch * ch_mult[-1]
self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
z_channels=intermediate_chn, double_z=False, resolution=resolution,
attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
out_ch=None)
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
def forward(self, x):
x = self.encoder(x)
x = self.rescaler(x)
return x
class MergedRescaleDecoder(nn.Module):
def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
super().__init__()
tmp_chn = z_channels*ch_mult[-1]
self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
ch_mult=ch_mult, resolution=resolution, ch=ch)
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
out_channels=tmp_chn, depth=rescale_module_depth)
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Upsampler(nn.Module):
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
super().__init__()
assert out_size >= in_size
num_blocks = int(np.log2(out_size//in_size))+1
factor_up = 1.+ (out_size % in_size)
print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
out_channels=in_channels)
self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
attn_resolutions=[], in_channels=None, ch=in_channels,
ch_mult=[ch_mult for _ in range(num_blocks)])
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Resize(nn.Module):
def __init__(self, in_channels=None, learned=False, mode="bilinea
================================================
SYMBOL INDEX (578 symbols across 29 files)
================================================
FILE: diffusers_models.py
class UNet2DConditionPatchModel (line 15) | class UNet2DConditionPatchModel(UNet2DConditionModel):
method __init__ (line 17) | def __init__(
method forward (line 105) | def forward(
FILE: diffusers_sample.py
function import_model_class_from_model_name_or_path (line 16) | def import_model_class_from_model_name_or_path(pretrained_model_name_or_...
class StableDiffusionGuidancePipeline (line 36) | class StableDiffusionGuidancePipeline(StableDiffusionPipeline):
method __init__ (line 39) | def __init__(
method add_pretrained_model (line 61) | def add_pretrained_model(self, text_encoder, unet):
method from_pretrained (line 66) | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
method _encode_prompt_orig (line 190) | def _encode_prompt_orig(
method __call__ (line 325) | def __call__(
function parse_args (line 482) | def parse_args(input_args=None):
FILE: diffusers_train.py
function import_model_class_from_model_name_or_path (line 57) | def import_model_class_from_model_name_or_path(pretrained_model_name_or_...
function parse_args (line 77) | def parse_args(input_args=None):
class SINEDatasetPatch (line 368) | class SINEDatasetPatch(Dataset):
method __init__ (line 369) | def __init__(
method __len__ (line 418) | def __len__(self):
method _random_crop (line 421) | def _random_crop(self, pil_image):
method __getitem__ (line 432) | def __getitem__(self, i):
class SINEDatasetSingleRes (line 453) | class SINEDatasetSingleRes(Dataset):
method __init__ (line 454) | def __init__(
method __len__ (line 493) | def __len__(self):
method __getitem__ (line 496) | def __getitem__(self, index):
function collate_fn (line 505) | def collate_fn(examples):
function get_full_repo_name (line 535) | def get_full_repo_name(model_id: str, organization: Optional[str] = None...
function main (line 545) | def main(args):
FILE: ldm/data/base.py
class Txt2ImgIterableBaseDataset (line 5) | class Txt2ImgIterableBaseDataset(IterableDataset):
method __init__ (line 9) | def __init__(self, num_records=0, valid_ids=None, size=256):
method __len__ (line 18) | def __len__(self):
method __iter__ (line 22) | def __iter__(self):
FILE: ldm/data/personalized.py
class PersonalizedBase (line 163) | class PersonalizedBase(Dataset):
method __init__ (line 164) | def __init__(self,
method __len__ (line 210) | def __len__(self):
method __getitem__ (line 213) | def __getitem__(self, i):
function crop_image (line 249) | def crop_image(img, size=512, cropping='random', crop_scale=[1, 1]):
class PersonalizedMulti (line 274) | class PersonalizedMulti(Dataset):
method __init__ (line 275) | def __init__(
method setup_templates (line 342) | def setup_templates(self, placeholder_token='sks', coarse_class_text='...
method __len__ (line 359) | def __len__(self):
method __getitem__ (line 362) | def __getitem__(self, i):
class SinImageDataset (line 390) | class SinImageDataset(PersonalizedBase):
method __init__ (line 391) | def __init__(
class SinImageHighResDataset (line 439) | class SinImageHighResDataset(Dataset):
method __init__ (line 440) | def __init__(self,
method __len__ (line 497) | def __len__(self):
method _random_crop (line 500) | def _random_crop(self, pil_image):
method __getitem__ (line 511) | def __getitem__(self, i):
FILE: ldm/data/personalized_painting.py
class PersonalizedBase (line 54) | class PersonalizedBase(Dataset):
method __init__ (line 55) | def __init__(self,
method __len__ (line 104) | def __len__(self):
method __getitem__ (line 107) | def __getitem__(self, i):
class SinImageDataset (line 151) | class SinImageDataset(PersonalizedBase):
method __init__ (line 152) | def __init__(
class SinImageHighResDataset (line 204) | class SinImageHighResDataset(Dataset):
method __init__ (line 205) | def __init__(self,
method __len__ (line 265) | def __len__(self):
method _random_crop (line 268) | def _random_crop(self, pil_image):
method __getitem__ (line 281) | def __getitem__(self, i):
FILE: ldm/lr_scheduler.py
class LambdaWarmUpCosineScheduler (line 4) | class LambdaWarmUpCosineScheduler:
method __init__ (line 8) | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_...
method schedule (line 17) | def schedule(self, n, **kwargs):
method __call__ (line 32) | def __call__(self, n, **kwargs):
class LambdaWarmUpCosineScheduler2 (line 36) | class LambdaWarmUpCosineScheduler2:
method __init__ (line 41) | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths...
method find_in_interval (line 52) | def find_in_interval(self, n):
method schedule (line 59) | def schedule(self, n, **kwargs):
method __call__ (line 77) | def __call__(self, n, **kwargs):
class LambdaLinearScheduler (line 81) | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
method schedule (line 83) | def schedule(self, n, **kwargs):
FILE: ldm/modules/attention.py
function exists (line 11) | def exists(val):
function uniq (line 15) | def uniq(arr):
function default (line 19) | def default(val, d):
function max_neg_value (line 25) | def max_neg_value(t):
function init_ (line 29) | def init_(tensor):
class GEGLU (line 37) | class GEGLU(nn.Module):
method __init__ (line 38) | def __init__(self, dim_in, dim_out):
method forward (line 42) | def forward(self, x):
class FeedForward (line 47) | class FeedForward(nn.Module):
method __init__ (line 48) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 63) | def forward(self, x):
function zero_module (line 67) | def zero_module(module):
function Normalize (line 76) | def Normalize(in_channels):
class LinearAttention (line 80) | class LinearAttention(nn.Module):
method __init__ (line 81) | def __init__(self, dim, heads=4, dim_head=32):
method forward (line 88) | def forward(self, x):
class SpatialSelfAttention (line 99) | class SpatialSelfAttention(nn.Module):
method __init__ (line 100) | def __init__(self, in_channels):
method forward (line 126) | def forward(self, x):
class CrossAttention (line 152) | class CrossAttention(nn.Module):
method __init__ (line 153) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
method forward (line 170) | def forward(self, x, context=None, mask=None):
class BasicTransformerBlock (line 196) | class BasicTransformerBlock(nn.Module):
method __init__ (line 197) | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None,...
method forward (line 208) | def forward(self, x, context=None):
method _forward (line 211) | def _forward(self, x, context=None):
class SpatialTransformer (line 218) | class SpatialTransformer(nn.Module):
method __init__ (line 226) | def __init__(self, in_channels, n_heads, d_head,
method forward (line 250) | def forward(self, x, context=None):
FILE: ldm/modules/diffusionmodules/model.py
function get_timestep_embedding (line 12) | def get_timestep_embedding(timesteps, embedding_dim):
function nonlinearity (line 33) | def nonlinearity(x):
function Normalize (line 38) | def Normalize(in_channels, num_groups=32):
class Upsample (line 42) | class Upsample(nn.Module):
method __init__ (line 43) | def __init__(self, in_channels, with_conv):
method forward (line 53) | def forward(self, x):
class Downsample (line 60) | class Downsample(nn.Module):
method __init__ (line 61) | def __init__(self, in_channels, with_conv):
method forward (line 72) | def forward(self, x):
class ResnetBlock (line 82) | class ResnetBlock(nn.Module):
method __init__ (line 83) | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=Fa...
method forward (line 121) | def forward(self, x, temb):
class LinAttnBlock (line 144) | class LinAttnBlock(LinearAttention):
method __init__ (line 146) | def __init__(self, in_channels):
class AttnBlock (line 150) | class AttnBlock(nn.Module):
method __init__ (line 151) | def __init__(self, in_channels):
method forward (line 178) | def forward(self, x):
function make_attn (line 205) | def make_attn(in_channels, attn_type="vanilla"):
class Model (line 216) | class Model(nn.Module):
method __init__ (line 217) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 316) | def forward(self, x, t=None, context=None):
method get_last_layer (line 364) | def get_last_layer(self):
class Encoder (line 368) | class Encoder(nn.Module):
method __init__ (line 369) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 434) | def forward(self, x):
class Decoder (line 462) | class Decoder(nn.Module):
method __init__ (line 463) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 535) | def forward(self, z):
class SimpleDecoder (line 571) | class SimpleDecoder(nn.Module):
method __init__ (line 572) | def __init__(self, in_channels, out_channels, *args, **kwargs):
method forward (line 594) | def forward(self, x):
class UpsampleDecoder (line 607) | class UpsampleDecoder(nn.Module):
method __init__ (line 608) | def __init__(self, in_channels, out_channels, ch, num_res_blocks, reso...
method forward (line 641) | def forward(self, x):
class LatentRescaler (line 655) | class LatentRescaler(nn.Module):
method __init__ (line 656) | def __init__(self, factor, in_channels, mid_channels, out_channels, de...
method forward (line 680) | def forward(self, x):
class MergedRescaleEncoder (line 692) | class MergedRescaleEncoder(nn.Module):
method __init__ (line 693) | def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
method forward (line 705) | def forward(self, x):
class MergedRescaleDecoder (line 711) | class MergedRescaleDecoder(nn.Module):
method __init__ (line 712) | def __init__(self, z_channels, out_ch, resolution, num_res_blocks, att...
method forward (line 722) | def forward(self, x):
class Upsampler (line 728) | class Upsampler(nn.Module):
method __init__ (line 729) | def __init__(self, in_size, out_size, in_channels, out_channels, ch_mu...
method forward (line 741) | def forward(self, x):
class Resize (line 747) | class Resize(nn.Module):
method __init__ (line 748) | def __init__(self, in_channels=None, learned=False, mode="bilinear"):
method forward (line 763) | def forward(self, x, scale_factor=1.0):
class FirstStagePostProcessor (line 770) | class FirstStagePostProcessor(nn.Module):
method __init__ (line 772) | def __init__(self, ch_mult:list, in_channels,
method instantiate_pretrained (line 807) | def instantiate_pretrained(self, config):
method encode_with_pretrained (line 816) | def encode_with_pretrained(self,x):
method forward (line 822) | def forward(self,x):
FILE: ldm/modules/diffusionmodules/openaimodel.py
function convert_module_to_f16 (line 30) | def convert_module_to_f16(x):
function convert_module_to_f32 (line 34) | def convert_module_to_f32(x):
class AttentionPool2d (line 39) | class AttentionPool2d(nn.Module):
method __init__ (line 44) | def __init__(
method forward (line 59) | def forward(self, x):
class TimestepBlock (line 70) | class TimestepBlock(nn.Module):
method forward (line 76) | def forward(self, x, emb):
class TimestepEmbedSequential (line 82) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
method forward (line 88) | def forward(self, x, emb, context=None):
class Upsample (line 99) | class Upsample(nn.Module):
method __init__ (line 108) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
method forward (line 118) | def forward(self, x):
class TransposedUpsample (line 131) | class TransposedUpsample(nn.Module):
method __init__ (line 134) | def __init__(self, channels, out_channels=None, ks=5):
method forward (line 142) | def forward(self, x):
class Downsample (line 146) | class Downsample(nn.Module):
method __init__ (line 155) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
method forward (line 170) | def forward(self, x):
class ResBlock (line 175) | class ResBlock(TimestepBlock):
method __init__ (line 191) | def __init__(
method forward (line 257) | def forward(self, x, emb):
method _forward (line 268) | def _forward(self, x, emb):
class AttentionBlock (line 291) | class AttentionBlock(nn.Module):
method __init__ (line 298) | def __init__(
method forward (line 327) | def forward(self, x):
method _forward (line 332) | def _forward(self, x):
function count_flops_attn (line 341) | def count_flops_attn(model, _x, y):
class QKVAttentionLegacy (line 361) | class QKVAttentionLegacy(nn.Module):
method __init__ (line 366) | def __init__(self, n_heads):
method forward (line 370) | def forward(self, qkv):
method count_flops (line 390) | def count_flops(model, _x, y):
class QKVAttention (line 394) | class QKVAttention(nn.Module):
method __init__ (line 399) | def __init__(self, n_heads):
method forward (line 403) | def forward(self, qkv):
method count_flops (line 425) | def count_flops(model, _x, y):
class UNetModel (line 429) | class UNetModel(nn.Module):
method __init__ (line 459) | def __init__(
method convert_to_fp16 (line 712) | def convert_to_fp16(self):
method convert_to_fp32 (line 720) | def convert_to_fp32(self):
method forward (line 728) | def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
class UNetModelPatch (line 764) | class UNetModelPatch(UNetModel):
method __init__ (line 765) | def __init__(
method forward (line 832) | def forward(self, x, timesteps=None, context=None, y=None, crop_boxes=...
class EncoderUNetModel (line 894) | class EncoderUNetModel(nn.Module):
method __init__ (line 900) | def __init__(
method convert_to_fp16 (line 1073) | def convert_to_fp16(self):
method convert_to_fp32 (line 1080) | def convert_to_fp32(self):
method forward (line 1087) | def forward(self, x, timesteps):
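
UNetModel is the denoiser at the heart of the pipeline: it maps noisy latents, integer timesteps, and a cross-attention context (the CLIP text embedding) to a noise prediction of the same shape; UNetModelPatch extends forward with crop_boxes for the patch-based fine-tuning. A minimal sketch with toy sizes (the Stable Diffusion v1 configs use much larger values, roughly model_channels=320, channel_mult=[1, 2, 4, 4], context_dim=768):

import torch
from ldm.modules.diffusionmodules.openaimodel import UNetModel

unet = UNetModel(
    image_size=32, in_channels=4, out_channels=4,
    model_channels=32, num_res_blocks=1,
    attention_resolutions=[1, 2], channel_mult=[1, 2],
    num_heads=4, use_spatial_transformer=True,
    transformer_depth=1, context_dim=768)

x = torch.randn(1, 4, 32, 32)           # noisy latents
t = torch.randint(0, 1000, (1,))        # diffusion timesteps
c = torch.randn(1, 77, 768)             # text conditioning, e.g. from FrozenCLIPEmbedder

eps = unet(x, timesteps=t, context=c)   # predicted noise, same shape as x
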
FILE: ldm/modules/diffusionmodules/positional_encoding.py
class SinusoidalPositionalEmbedding (line 15) | class SinusoidalPositionalEmbedding(nn.Module):
method __init__ (line 39) | def __init__(self,
method get_embedding (line 59) | def get_embedding(num_embeddings,
method forward (line 89) | def forward(self, input, **kwargs):
method make_positions (line 115) | def make_positions(self, input, padding_idx):
method make_grid2d (line 120) | def make_grid2d(self, height, width, num_batches=1, center_shift=None):
method make_grid2d_like (line 168) | def make_grid2d_like(self, x, center_shift=None):
class CatersianGrid (line 182) | class CatersianGrid(nn.Module):
method forward (line 191) | def forward(self, x, **kwargs):
method make_grid2d (line 195) | def make_grid2d(self, height, width, num_batches=1, requires_grad=False):
method make_grid2d_like (line 208) | def make_grid2d_like(self, x, requires_grad=False):
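
These modules build 2-D positional grids (sinusoidal embeddings or plain Cartesian coordinates) that can be concatenated to feature maps so a crop knows where it sits in the full image. A small sketch of the coordinate grid, assuming the mmgeneration-style behaviour the class names suggest:

from ldm.modules.diffusionmodules.positional_encoding import CatersianGrid

csg = CatersianGrid()
grid = csg.make_grid2d(16, 16, num_batches=2)
print(grid.shape)   # expected: [2, 2, 16, 16], an x and a y coordinate plane in [-1, 1]
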
FILE: ldm/modules/diffusionmodules/util.py
function make_beta_schedule (line 21) | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_e...
function make_ddim_timesteps (line 46) | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_...
function make_ddim_sampling_parameters (line 63) | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbos...
function betas_for_alpha_bar (line 77) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
function extract_into_tensor (line 96) | def extract_into_tensor(a, t, x_shape):
function checkpoint (line 102) | def checkpoint(func, inputs, params, flag):
class CheckpointFunction (line 119) | class CheckpointFunction(torch.autograd.Function):
method forward (line 121) | def forward(ctx, run_function, length, *args):
method backward (line 131) | def backward(ctx, *output_grads):
function timestep_embedding (line 151) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
function zero_module (line 174) | def zero_module(module):
function scale_module (line 183) | def scale_module(module, scale):
function mean_flat (line 192) | def mean_flat(tensor):
function normalization (line 199) | def normalization(channels):
class SiLU (line 209) | class SiLU(nn.Module):
method forward (line 210) | def forward(self, x):
class GroupNorm32 (line 214) | class GroupNorm32(nn.GroupNorm):
method forward (line 215) | def forward(self, x):
function conv_nd (line 218) | def conv_nd(dims, *args, **kwargs):
function linear (line 231) | def linear(*args, **kwargs):
function avg_pool_nd (line 238) | def avg_pool_nd(dims, *args, **kwargs):
class HybridConditioner (line 251) | class HybridConditioner(nn.Module):
method __init__ (line 253) | def __init__(self, c_concat_config, c_crossattn_config):
method forward (line 258) | def forward(self, c_concat, c_crossattn):
function noise_like (line 264) | def noise_like(shape, device, repeat=False):
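
These helpers are the numerical core of the diffusion process: make_beta_schedule builds the noise schedule, extract_into_tensor broadcasts per-timestep scalars over a batch of latents, and timestep_embedding produces the sinusoidal features the UNet consumes. A minimal sketch of the forward noising step q(x_t | x_0) composed from them:

import torch
from ldm.modules.diffusionmodules.util import (
    make_beta_schedule, extract_into_tensor, timestep_embedding, noise_like)

betas = torch.tensor(make_beta_schedule("linear", n_timestep=1000), dtype=torch.float32)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sqrt_ac, sqrt_1m_ac = alphas_cumprod.sqrt(), (1.0 - alphas_cumprod).sqrt()

x0 = torch.randn(2, 4, 32, 32)            # clean latents
t = torch.randint(0, 1000, (2,))          # one timestep per sample
noise = noise_like(x0.shape, x0.device)   # fresh Gaussian noise

# Gather the per-timestep scalars and broadcast them over the latent shape.
x_t = extract_into_tensor(sqrt_ac, t, x0.shape) * x0 \
    + extract_into_tensor(sqrt_1m_ac, t, x0.shape) * noise

t_emb = timestep_embedding(t, dim=320)    # [2, 320] sinusoidal features for the UNet
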
FILE: ldm/modules/distributions/distributions.py
class AbstractDistribution (line 5) | class AbstractDistribution:
method sample (line 6) | def sample(self):
method mode (line 9) | def mode(self):
class DiracDistribution (line 13) | class DiracDistribution(AbstractDistribution):
method __init__ (line 14) | def __init__(self, value):
method sample (line 17) | def sample(self):
method mode (line 20) | def mode(self):
class DiagonalGaussianDistribution (line 24) | class DiagonalGaussianDistribution(object):
method __init__ (line 25) | def __init__(self, parameters, deterministic=False):
method sample (line 35) | def sample(self):
method kl (line 39) | def kl(self, other=None):
method nll (line 53) | def nll(self, sample, dims=[1,2,3]):
method mode (line 61) | def mode(self):
function normal_kl (line 65) | def normal_kl(mean1, logvar1, mean2, logvar2):
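
DiagonalGaussianDistribution wraps the first-stage (VAE) encoder output: a tensor holding the mean and log-variance concatenated along the channel axis, with reparameterized sampling and a KL term against the standard normal. Sketch:

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

moments = torch.randn(1, 8, 32, 32)   # 2 * z_channels maps: mean, then log-variance
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()                # [1, 4, 32, 32] reparameterized draw
kl = posterior.kl()                   # KL(q || N(0, I)), one value per sample
z_det = posterior.mode()              # the mean, for deterministic encoding
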
FILE: ldm/modules/ema.py
class LitEma (line 5) | class LitEma(nn.Module):
method __init__ (line 6) | def __init__(self, model, decay=0.9999, use_num_upates=True):
method forward (line 25) | def forward(self,model):
method copy_to (line 46) | def copy_to(self, model):
method store (line 55) | def store(self, parameters):
method restore (line 64) | def restore(self, parameters):
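
LitEma keeps an exponential moving average of a module's parameters as registered buffers. The usual pattern is to update it after every optimizer step and to swap the EMA weights in temporarily for evaluation or sampling:

import torch.nn as nn
from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.9999)

ema(model)                        # after each optimizer step: fold live weights into the shadow copy

ema.store(model.parameters())     # stash the live weights
ema.copy_to(model)                # load the EMA weights for eval / sampling
# ... evaluate ...
ema.restore(model.parameters())   # restore the live weights for further training
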
FILE: ldm/modules/embedding_manager.py
function get_clip_token_for_string (line 12) | def get_clip_token_for_string(tokenizer, string):
function get_bert_token_for_string (line 20) | def get_bert_token_for_string(tokenizer, string):
function get_embedding_for_clip_token (line 28) | def get_embedding_for_clip_token(embedder, token):
class EmbeddingManager (line 32) | class EmbeddingManager(nn.Module):
method __init__ (line 33) | def __init__(
method forward (line 88) | def forward(
method save (line 131) | def save(self, ckpt_path):
method load (line 135) | def load(self, ckpt_path):
method get_embedding_norms_squared (line 141) | def get_embedding_norms_squared(self):
method embedding_parameters (line 147) | def embedding_parameters(self):
method embedding_to_coarse_loss (line 150) | def embedding_to_coarse_loss(self):
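
EmbeddingManager implements the textual-inversion mechanism: a placeholder token (e.g. "*") in the prompt is replaced by one or more learned embedding vectors at conditioning time, and only those vectors are optimized. The sketch below follows the textual-inversion calling convention this file derives from; the exact constructor arguments accepted in this fork may differ, and the output path is hypothetical:

from ldm.modules.embedding_manager import EmbeddingManager
from ldm.modules.encoders.modules import FrozenCLIPEmbedder

clip_text = FrozenCLIPEmbedder(device="cpu")
manager = EmbeddingManager(clip_text,
                           placeholder_strings=["*"],       # token whose embedding is learned
                           initializer_words=["painting"])  # word that seeds the new embedding

# The manager swaps its learned vector(s) in wherever "*" appears in the prompt.
cond = clip_text.encode(["a photo of *"], embedding_manager=manager)

manager.save("learned_embeddings.pt")   # hypothetical output path
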
FILE: ldm/modules/encoders/modules.py
function _expand_mask (line 11) | def _expand_mask(mask, dtype, tgt_len = None):
function _build_causal_attention_mask (line 24) | def _build_causal_attention_mask(bsz, seq_len, dtype):
class AbstractEncoder (line 33) | class AbstractEncoder(nn.Module):
method __init__ (line 34) | def __init__(self):
method encode (line 37) | def encode(self, *args, **kwargs):
class ClassEmbedder (line 42) | class ClassEmbedder(nn.Module):
method __init__ (line 43) | def __init__(self, embed_dim, n_classes=1000, key='class'):
method forward (line 48) | def forward(self, batch, key=None):
class TransformerEmbedder (line 57) | class TransformerEmbedder(AbstractEncoder):
method __init__ (line 59) | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, devic...
method forward (line 65) | def forward(self, tokens):
method encode (line 70) | def encode(self, x):
class BERTTokenizer (line 74) | class BERTTokenizer(AbstractEncoder):
method __init__ (line 76) | def __init__(self, device="cuda", vq_interface=True, max_length=77):
method forward (line 84) | def forward(self, text):
method encode (line 91) | def encode(self, text):
method decode (line 97) | def decode(self, text):
class BERTEmbedder (line 101) | class BERTEmbedder(AbstractEncoder):
method __init__ (line 103) | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
method forward (line 114) | def forward(self, text, embedding_manager=None):
method encode (line 122) | def encode(self, text, **kwargs):
class SpatialRescaler (line 126) | class SpatialRescaler(nn.Module):
method __init__ (line 127) | def __init__(self,
method forward (line 145) | def forward(self,x):
method encode (line 154) | def encode(self, x):
class FrozenCLIPEmbedder (line 157) | class FrozenCLIPEmbedder(AbstractEncoder):
method __init__ (line 159) | def __init__(self, version="openai/clip-vit-large-patch14", device="cu...
method freeze (line 310) | def freeze(self):
method forward (line 315) | def forward(self, text, **kwargs):
method encode (line 324) | def encode(self, text, **kwargs):
class FrozenCLIPTextEmbedder (line 328) | class FrozenCLIPTextEmbedder(nn.Module):
method __init__ (line 332) | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n...
method freeze (line 340) | def freeze(self):
method forward (line 345) | def forward(self, text):
method encode (line 352) | def encode(self, text):
class FrozenClipImageEmbedder (line 360) | class FrozenClipImageEmbedder(nn.Module):
method __init__ (line 364) | def __init__(
method preprocess (line 379) | def preprocess(self, x):
method forward (line 389) | def forward(self, x):
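
FrozenCLIPEmbedder is the conditioning stage used by the Stable Diffusion configs: it tokenizes the prompt with the Hugging Face CLIP tokenizer and returns the frozen text encoder's hidden states, which the UNet cross-attends to. Sketch (the model named in `version` is downloaded on first use):

from ldm.modules.encoders.modules import FrozenCLIPEmbedder

clip_text = FrozenCLIPEmbedder(device="cpu", max_length=77)
cond = clip_text.encode(["a photo of a cat on the beach"])
print(cond.shape)   # torch.Size([1, 77, 768]) for clip-vit-large-patch14

modules_bak.py below appears to be a backup copy of the same file with the same public interface.
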
FILE: ldm/modules/encoders/modules_bak.py
function _expand_mask (line 11) | def _expand_mask(mask, dtype, tgt_len = None):
function _build_causal_attention_mask (line 24) | def _build_causal_attention_mask(bsz, seq_len, dtype):
class AbstractEncoder (line 33) | class AbstractEncoder(nn.Module):
method __init__ (line 34) | def __init__(self):
method encode (line 37) | def encode(self, *args, **kwargs):
class ClassEmbedder (line 42) | class ClassEmbedder(nn.Module):
method __init__ (line 43) | def __init__(self, embed_dim, n_classes=1000, key='class'):
method forward (line 48) | def forward(self, batch, key=None):
class TransformerEmbedder (line 57) | class TransformerEmbedder(AbstractEncoder):
method __init__ (line 59) | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, devic...
method forward (line 65) | def forward(self, tokens):
method encode (line 70) | def encode(self, x):
class BERTTokenizer (line 74) | class BERTTokenizer(AbstractEncoder):
method __init__ (line 76) | def __init__(self, device="cuda", vq_interface=True, max_length=77):
method forward (line 84) | def forward(self, text):
method encode (line 91) | def encode(self, text):
method decode (line 97) | def decode(self, text):
class BERTEmbedder (line 101) | class BERTEmbedder(AbstractEncoder):
method __init__ (line 103) | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
method forward (line 114) | def forward(self, text, embedding_manager=None):
method encode (line 122) | def encode(self, text, **kwargs):
class SpatialRescaler (line 126) | class SpatialRescaler(nn.Module):
method __init__ (line 127) | def __init__(self,
method forward (line 145) | def forward(self,x):
method encode (line 154) | def encode(self, x):
class FrozenCLIPEmbedder (line 157) | class FrozenCLIPEmbedder(AbstractEncoder):
method __init__ (line 159) | def __init__(self, version="openai/clip-vit-large-patch14", device="cu...
method freeze (line 410) | def freeze(self):
method forward (line 415) | def forward(self, text, **kwargs):
method encode (line 423) | def encode(self, text, **kwargs):
class FrozenCLIPTextEmbedder (line 427) | class FrozenCLIPTextEmbedder(nn.Module):
method __init__ (line 431) | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n...
method freeze (line 439) | def freeze(self):
method forward (line 444) | def forward(self, text):
method encode (line 451) | def encode(self, text):
class FrozenClipImageEmbedder (line 459) | class FrozenClipImageEmbedder(nn.Module):
method __init__ (line 463) | def __init__(
method preprocess (line 478) | def preprocess(self, x):
method forward (line 488) | def forward(self, x):
FILE: ldm/modules/image_degradation/bsrgan.py
function modcrop_np (line 29) | def modcrop_np(img, sf):
function analytic_kernel (line 49) | def analytic_kernel(k):
function anisotropic_Gaussian (line 65) | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
function gm_blur_kernel (line 86) | def gm_blur_kernel(mean, cov, size=15):
function shift_pixel (line 99) | def shift_pixel(x, sf, upper_left=True):
function blur (line 128) | def blur(x, k):
function gen_kernel (line 145) | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]),...
function fspecial_gaussian (line 187) | def fspecial_gaussian(hsize, sigma):
function fspecial_laplacian (line 201) | def fspecial_laplacian(alpha):
function fspecial (line 210) | def fspecial(filter_type, *args, **kwargs):
function bicubic_degradation (line 228) | def bicubic_degradation(x, sf=3):
function srmd_degradation (line 240) | def srmd_degradation(x, k, sf=3):
function dpsr_degradation (line 262) | def dpsr_degradation(x, k, sf=3):
function classical_degradation (line 284) | def classical_degradation(x, k, sf=3):
function add_sharpening (line 299) | def add_sharpening(img, weight=0.5, radius=50, threshold=10):
function add_blur (line 325) | def add_blur(img, sf=4):
function add_resize (line 339) | def add_resize(img, sf=4):
function add_Gaussian_noise (line 369) | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
function add_speckle_noise (line 386) | def add_speckle_noise(img, noise_level1=2, noise_level2=25):
function add_Poisson_noise (line 404) | def add_Poisson_noise(img):
function add_JPEG_noise (line 418) | def add_JPEG_noise(img):
function random_crop (line 427) | def random_crop(lq, hq, sf=4, lq_patchsize=64):
function degradation_bsrgan (line 438) | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
function degradation_bsrgan_variant (line 530) | def degradation_bsrgan_variant(image, sf=4, isp_model=None):
function degradation_bsrgan_plus (line 617) | def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True,...
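
bsrgan.py implements the BSRGAN degradation pipeline (random blur, rescaling, Gaussian/JPEG noise) used to synthesize low-quality inputs for super-resolution training; bsrgan_light.py below is a milder variant of the same functions, and the package __init__ exports degradation_bsrgan_variant as degradation_fn_bsr. A hedged sketch, assuming the upstream latent-diffusion behaviour where the variant takes an HWC image and returns a dict containing the degraded image:

import numpy as np
from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant

img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)   # HWC, RGB

out = degradation_bsrgan_variant(img, sf=4)
# In the upstream code `out` is a dict like {"image": degraded_low_res_image}.
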
FILE: ldm/modules/image_degradation/bsrgan_light.py
function modcrop_np (line 29) | def modcrop_np(img, sf):
function analytic_kernel (line 49) | def analytic_kernel(k):
function anisotropic_Gaussian (line 65) | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
function gm_blur_kernel (line 86) | def gm_blur_kernel(mean, cov, size=15):
function shift_pixel (line 99) | def shift_pixel(x, sf, upper_left=True):
function blur (line 128) | def blur(x, k):
function gen_kernel (line 145) | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]),...
function fspecial_gaussian (line 187) | def fspecial_gaussian(hsize, sigma):
function fspecial_laplacian (line 201) | def fspecial_laplacian(alpha):
function fspecial (line 210) | def fspecial(filter_type, *args, **kwargs):
function bicubic_degradation (line 228) | def bicubic_degradation(x, sf=3):
function srmd_degradation (line 240) | def srmd_degradation(x, k, sf=3):
function dpsr_degradation (line 262) | def dpsr_degradation(x, k, sf=3):
function classical_degradation (line 284) | def classical_degradation(x, k, sf=3):
function add_sharpening (line 299) | def add_sharpening(img, weight=0.5, radius=50, threshold=10):
function add_blur (line 325) | def add_blur(img, sf=4):
function add_resize (line 343) | def add_resize(img, sf=4):
function add_Gaussian_noise (line 373) | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
function add_speckle_noise (line 390) | def add_speckle_noise(img, noise_level1=2, noise_level2=25):
function add_Poisson_noise (line 408) | def add_Poisson_noise(img):
function add_JPEG_noise (line 422) | def add_JPEG_noise(img):
function random_crop (line 431) | def random_crop(lq, hq, sf=4, lq_patchsize=64):
function degradation_bsrgan (line 442) | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
function degradation_bsrgan_variant (line 534) | def degradation_bsrgan_variant(image, sf=4, isp_model=None):
FILE: ldm/modules/image_degradation/utils_image.py
function is_image_file (line 29) | def is_image_file(filename):
function get_timestamp (line 33) | def get_timestamp():
function imshow (line 37) | def imshow(x, title=None, cbar=False, figsize=None):
function surf (line 47) | def surf(Z, cmap='rainbow', figsize=None):
function get_image_paths (line 67) | def get_image_paths(dataroot):
function _get_paths_from_images (line 74) | def _get_paths_from_images(path):
function patches_from_image (line 93) | def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
function imssave (line 112) | def imssave(imgs, img_path):
function split_imageset (line 125) | def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_si...
function mkdir (line 153) | def mkdir(path):
function mkdirs (line 158) | def mkdirs(paths):
function mkdir_and_rename (line 166) | def mkdir_and_rename(path):
function imread_uint (line 185) | def imread_uint(path, n_channels=3):
function imsave (line 203) | def imsave(img, img_path):
function imwrite (line 209) | def imwrite(img, img_path):
function read_img (line 220) | def read_img(path):
function uint2single (line 249) | def uint2single(img):
function single2uint (line 254) | def single2uint(img):
function uint162single (line 259) | def uint162single(img):
function single2uint16 (line 264) | def single2uint16(img):
function uint2tensor4 (line 275) | def uint2tensor4(img):
function uint2tensor3 (line 282) | def uint2tensor3(img):
function tensor2uint (line 289) | def tensor2uint(img):
function single2tensor3 (line 302) | def single2tensor3(img):
function single2tensor4 (line 307) | def single2tensor4(img):
function tensor2single (line 312) | def tensor2single(img):
function tensor2single3 (line 320) | def tensor2single3(img):
function single2tensor5 (line 329) | def single2tensor5(img):
function single32tensor5 (line 333) | def single32tensor5(img):
function single42tensor4 (line 337) | def single42tensor4(img):
function tensor2img (line 342) | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
function augment_img (line 380) | def augment_img(img, mode=0):
function augment_img_tensor4 (line 401) | def augment_img_tensor4(img, mode=0):
function augment_img_tensor (line 422) | def augment_img_tensor(img, mode=0):
function augment_img_np3 (line 441) | def augment_img_np3(img, mode=0):
function augment_imgs (line 469) | def augment_imgs(img_list, hflip=True, rot=True):
function modcrop (line 494) | def modcrop(img_in, scale):
function shave (line 510) | def shave(img_in, border=0):
function rgb2ycbcr (line 529) | def rgb2ycbcr(img, only_y=True):
function ycbcr2rgb (line 553) | def ycbcr2rgb(img):
function bgr2ycbcr (line 573) | def bgr2ycbcr(img, only_y=True):
function channel_convert (line 597) | def channel_convert(in_c, tar_type, img_list):
function calculate_psnr (line 621) | def calculate_psnr(img1, img2, border=0):
function calculate_ssim (line 642) | def calculate_ssim(img1, img2, border=0):
function ssim (line 669) | def ssim(img1, img2):
function cubic (line 700) | def cubic(x):
function calculate_weights_indices (line 708) | def calculate_weights_indices(in_length, out_length, scale, kernel, kern...
function imresize (line 766) | def imresize(img, scale, antialiasing=True):
function imresize_np (line 839) | def imresize_np(img, scale, antialiasing=True):
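
utils_image.py is a grab bag of numpy/torch image conversions and metrics shared by the degradation code. The naming convention is <from>2<to>, with uint = HWC uint8, single = HWC float32 in [0, 1], and tensor = CHW torch tensors. Sketch:

import numpy as np
import ldm.modules.image_degradation.utils_image as util

img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)   # HWC uint8

x = util.uint2single(img)      # HWC float32 in [0, 1]
t = util.single2tensor3(x)     # CHW float tensor
back = util.single2uint(x)     # back to uint8

print(util.calculate_psnr(img, back))   # inf for this lossless round trip
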
FILE: ldm/modules/losses/contperceptual.py
class LPIPSWithDiscriminator (line 7) | class LPIPSWithDiscriminator(nn.Module):
method __init__ (line 8) | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixello...
method calculate_adaptive_weight (line 32) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
method forward (line 45) | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
FILE: ldm/modules/losses/vqperceptual.py
function hinge_d_loss_with_exemplar_weights (line 11) | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
function adopt_weight (line 20) | def adopt_weight(weight, global_step, threshold=0, value=0.):
function measure_perplexity (line 26) | def measure_perplexity(predicted_indices, n_embed):
function l1 (line 35) | def l1(x, y):
function l2 (line 39) | def l2(x, y):
class VQLPIPSWithDiscriminator (line 43) | class VQLPIPSWithDiscriminator(nn.Module):
method __init__ (line 44) | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
method calculate_adaptive_weight (line 85) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
method forward (line 98) | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
FILE: ldm/modules/x_transformer.py
class AbsolutePositionalEmbedding (line 25) | class AbsolutePositionalEmbedding(nn.Module):
method __init__ (line 26) | def __init__(self, dim, max_seq_len):
method init_ (line 31) | def init_(self):
method forward (line 34) | def forward(self, x):
class FixedPositionalEmbedding (line 39) | class FixedPositionalEmbedding(nn.Module):
method __init__ (line 40) | def __init__(self, dim):
method forward (line 45) | def forward(self, x, seq_dim=1, offset=0):
function exists (line 54) | def exists(val):
function default (line 58) | def default(val, d):
function always (line 64) | def always(val):
function not_equals (line 70) | def not_equals(val):
function equals (line 76) | def equals(val):
function max_neg_value (line 82) | def max_neg_value(tensor):
function pick_and_pop (line 88) | def pick_and_pop(keys, d):
function group_dict_by_key (line 93) | def group_dict_by_key(cond, d):
function string_begins_with (line 102) | def string_begins_with(prefix, str):
function group_by_key_prefix (line 106) | def group_by_key_prefix(prefix, d):
function groupby_prefix_and_trim (line 110) | def groupby_prefix_and_trim(prefix, d):
class Scale (line 117) | class Scale(nn.Module):
method __init__ (line 118) | def __init__(self, value, fn):
method forward (line 123) | def forward(self, x, **kwargs):
class Rezero (line 128) | class Rezero(nn.Module):
method __init__ (line 129) | def __init__(self, fn):
method forward (line 134) | def forward(self, x, **kwargs):
class ScaleNorm (line 139) | class ScaleNorm(nn.Module):
method __init__ (line 140) | def __init__(self, dim, eps=1e-5):
method forward (line 146) | def forward(self, x):
class RMSNorm (line 151) | class RMSNorm(nn.Module):
method __init__ (line 152) | def __init__(self, dim, eps=1e-8):
method forward (line 158) | def forward(self, x):
class Residual (line 163) | class Residual(nn.Module):
method forward (line 164) | def forward(self, x, residual):
class GRUGating (line 168) | class GRUGating(nn.Module):
method __init__ (line 169) | def __init__(self, dim):
method forward (line 173) | def forward(self, x, residual):
class GEGLU (line 184) | class GEGLU(nn.Module):
method __init__ (line 185) | def __init__(self, dim_in, dim_out):
method forward (line 189) | def forward(self, x):
class FeedForward (line 194) | class FeedForward(nn.Module):
method __init__ (line 195) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 210) | def forward(self, x):
class Attention (line 215) | class Attention(nn.Module):
method __init__ (line 216) | def __init__(
method forward (line 268) | def forward(
class AttentionLayers (line 370) | class AttentionLayers(nn.Module):
method __init__ (line 371) | def __init__(
method forward (line 481) | def forward(
class Encoder (line 542) | class Encoder(AttentionLayers):
method __init__ (line 543) | def __init__(self, **kwargs):
class TransformerWrapper (line 549) | class TransformerWrapper(nn.Module):
method __init__ (line 550) | def __init__(
method init_ (line 596) | def init_(self):
method forward (line 599) | def forward(
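
x_transformer.py is a trimmed copy of lucidrains' x-transformers; BERTEmbedder wraps a TransformerWrapper around an Encoder stack to produce token embeddings when a CLIP text encoder is not used. Sketch:

import torch
from ldm.modules.x_transformer import Encoder, TransformerWrapper

# Mirrors how BERTEmbedder builds its backbone (vocab 30522, 77-token sequences).
model = TransformerWrapper(num_tokens=30522, max_seq_len=77,
                           attn_layers=Encoder(dim=512, depth=4))

tokens = torch.randint(0, 30522, (2, 77))
hidden = model(tokens, return_embeddings=True)   # [2, 77, 512]
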
FILE: ldm/util.py
function log_txt_as_img (line 17) | def log_txt_as_img(wh, xc, size=10):
function ismap (line 41) | def ismap(x):
function isimage (line 47) | def isimage(x):
function exists (line 53) | def exists(x):
function default (line 57) | def default(val, d):
function mean_flat (line 63) | def mean_flat(tensor):
function count_params (line 71) | def count_params(model, verbose=False):
function instantiate_from_config (line 78) | def instantiate_from_config(config, **kwargs):
function get_obj_from_str (line 88) | def get_obj_from_str(string, reload=False):
function _do_parallel_data_prefetch (line 96) | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
function parallel_data_prefetch (line 108) | def parallel_data_prefetch(
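
instantiate_from_config is the glue used throughout the repo: every YAML node with a `target` import path (and optional `params`) is turned into a live object with it. Sketch:

from ldm.util import instantiate_from_config, count_params

cfg = {"target": "torch.nn.Linear",
       "params": {"in_features": 4, "out_features": 2}}
layer = instantiate_from_config(cfg)
count_params(layer, verbose=True)   # prints the parameter count
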
FILE: main.py
function load_model_from_config (line 24) | def load_model_from_config(config, ckpt, verbose=False):
function get_parser (line 41) | def get_parser(**parser_kwargs):
function nondefault_trainer_args (line 181) | def nondefault_trainer_args(opt):
class WrappedDataset (line 188) | class WrappedDataset(Dataset):
method __init__ (line 191) | def __init__(self, dataset):
method __len__ (line 194) | def __len__(self):
method __getitem__ (line 197) | def __getitem__(self, idx):
function worker_init_fn (line 201) | def worker_init_fn(_):
class ConcatDataset (line 216) | class ConcatDataset(Dataset):
method __init__ (line 217) | def __init__(self, *datasets):
method __getitem__ (line 220) | def __getitem__(self, idx):
method __len__ (line 223) | def __len__(self):
class DataModuleFromConfig (line 226) | class DataModuleFromConfig(pl.LightningDataModule):
method __init__ (line 227) | def __init__(self, batch_size, train=None, reg = None, validation=None...
method prepare_data (line 253) | def prepare_data(self):
method setup (line 257) | def setup(self, stage=None):
method _train_dataloader (line 265) | def _train_dataloader(self):
method _val_dataloader (line 278) | def _val_dataloader(self, shuffle=False):
method _test_dataloader (line 289) | def _test_dataloader(self, shuffle=False):
method _predict_dataloader (line 302) | def _predict_dataloader(self, shuffle=False):
class SetupCallback (line 311) | class SetupCallback(Callback):
method __init__ (line 312) | def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, light...
method on_keyboard_interrupt (line 322) | def on_keyboard_interrupt(self, trainer, pl_module):
method on_pretrain_routine_start (line 328) | def on_pretrain_routine_start(self, trainer, pl_module):
class ImageLogger (line 360) | class ImageLogger(Callback):
method __init__ (line 361) | def __init__(self, batch_frequency, max_images, clamp=True, increase_l...
method _testtube (line 381) | def _testtube(self, pl_module, images, batch_idx, split):
method log_local (line 392) | def log_local(self, save_dir, split, images,
method log_img (line 411) | def log_img(self, pl_module, batch, batch_idx, split="train"):
method check_frequency (line 443) | def check_frequency(self, check_idx):
method on_train_batch_end (line 454) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
method on_validation_batch_end (line 458) | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, ...
class CUDACallback (line 466) | class CUDACallback(Callback):
method on_train_epoch_start (line 468) | def on_train_epoch_start(self, trainer, pl_module):
method on_train_epoch_end (line 474) | def on_train_epoch_end(self, trainer, pl_module):
class ModeSwapCallback (line 488) | class ModeSwapCallback(Callback):
method __init__ (line 490) | def __init__(self, swap_step=2000):
method on_train_epoch_start (line 495) | def on_train_epoch_start(self, trainer, pl_module):
function melk (line 768) | def melk(*args, **kwargs):
function divein (line 776) | def divein(*args, **kwargs):
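
main.py is the PyTorch Lightning training entry point (DataModuleFromConfig plus the logging and checkpoint callbacks); its load_model_from_config helper is also useful on its own for loading a checkpoint against one of the YAML configs. A sketch, assuming a CUDA device (the helper moves the model to GPU) and using a placeholder path for whatever Stable Diffusion checkpoint you have downloaded:

from omegaconf import OmegaConf
from main import load_model_from_config

config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
model = load_model_from_config(config, "path/to/sd-v1-4.ckpt")   # placeholder checkpoint path
model.eval()
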
FILE: scripts/sample_diffusion.py
function custom_to_pil (line 15) | def custom_to_pil(x):
function custom_to_np (line 27) | def custom_to_np(x):
function logs2pil (line 36) | def logs2pil(logs, keys=["sample"]):
function convsample (line 54) | def convsample(model, shape, return_intermediates=True,
function convsample_ddim (line 69) | def convsample_ddim(model, steps, shape, eta=1.0
function make_convolutional_sample (line 79) | def make_convolutional_sample(model, batch_size, vanilla=False, custom_s...
function run (line 108) | def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, ...
function save_logs (line 143) | def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
function get_parser (line 162) | def get_parser():
function load_model_from_config (line 220) | def load_model_from_config(config, sd):
function load_model (line 228) | def load_model(config, ckpt, gpu, eval_mode):
FILE: scripts/score.py
function _transform (line 26) | def _transform():
function _convert_image_to_rgb (line 35) | def _convert_image_to_rgb(image):
FILE: scripts/stable_txt2img_guidance.py
function chunk (line 20) | def chunk(it, size):
function load_model_from_config (line 25) | def load_model_from_config(config, ckpt, verbose=False):
function main (line 45) | def main():
FILE: scripts/stable_txt2img_multi_guidance.py
function chunk (line 20) | def chunk(it, size):
function load_model_from_config (line 25) | def load_model_from_config(config, ckpt, verbose=False):
function main (line 45) | def main():
Condensed preview — 51 files, each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 2512,
"preview": "# General\n.DS_Store\n.AppleDouble\n.LSOverride\nmodels\n# models/\nlogs\nlogs/\nexps/\ndatasets/\nsrc/\n# Icon must end with two \\"
},
{
"path": "LICENSE",
"chars": 1125,
"preview": "MIT License\n\nCopyright (c) 2022 Rinon Gal, Yuval Alaluf, Yuval Atzmon, Or Patashnik and contributors\n\nPermission is here"
},
{
"path": "README.md",
"chars": 5945,
"preview": "## SINE <br><sub> <ins>SIN</ins>gle Image <ins>E</ins>diting with Text-to-Image Diffusion Models</sub>\n\n[:\n raise NotImplementedError()\n"
},
{
"path": "ldm/modules/ema.py",
"chars": 2982,
"preview": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n def __init__(self, model, decay=0.9999, use_num_upates="
},
{
"path": "ldm/modules/embedding_manager.py",
"chars": 6630,
"preview": "import torch\nfrom torch import nn\n\nfrom ldm.data.personalized import per_img_token_list\nfrom transformers import CLIPTok"
},
{
"path": "ldm/modules/encoders/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ldm/modules/encoders/modules.py",
"chars": 14841,
"preview": "import torch\nimport torch.nn as nn\nfrom functools import partial\nimport clip\nfrom einops import rearrange, repeat\nfrom t"
},
{
"path": "ldm/modules/encoders/modules_bak.py",
"chars": 18883,
"preview": "import torch\nimport torch.nn as nn\nfrom functools import partial\nimport clip\nfrom einops import rearrange, repeat\nfrom t"
},
{
"path": "ldm/modules/image_degradation/__init__.py",
"chars": 208,
"preview": "from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr\nfrom ldm.modules.image"
},
{
"path": "ldm/modules/image_degradation/bsrgan.py",
"chars": 25198,
"preview": "# -*- coding: utf-8 -*-\n\"\"\"\n# --------------------------------------------\n# Super-Resolution\n# ------------------------"
},
{
"path": "ldm/modules/image_degradation/bsrgan_light.py",
"chars": 22238,
"preview": "# -*- coding: utf-8 -*-\nimport numpy as np\nimport cv2\nimport torch\n\nfrom functools import partial\nimport random\nfrom sci"
},
{
"path": "ldm/modules/image_degradation/utils_image.py",
"chars": 29022,
"preview": "import os\nimport math\nimport random\nimport numpy as np\nimport torch\nimport cv2\nfrom torchvision.utils import make_grid\nf"
},
{
"path": "ldm/modules/losses/__init__.py",
"chars": 68,
"preview": "from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator"
},
{
"path": "ldm/modules/losses/contperceptual.py",
"chars": 5581,
"preview": "import torch\nimport torch.nn as nn\n\nfrom taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?\n"
},
{
"path": "ldm/modules/losses/vqperceptual.py",
"chars": 7941,
"preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom einops import repeat\n\nfrom taming.modules.discrim"
},
{
"path": "ldm/modules/x_transformer.py",
"chars": 20369,
"preview": "\"\"\"shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers\"\"\"\nimport torch\nfrom torch import "
},
{
"path": "ldm/util.py",
"chars": 5877,
"preview": "import importlib\n\nimport torch\nimport numpy as np\nfrom collections import abc\nfrom einops import rearrange\nfrom functool"
},
{
"path": "main.py",
"chars": 30789,
"preview": "import argparse, os, sys, datetime, glob, importlib, csv\nimport numpy as np\nimport time\nimport torch\n\nimport torchvision"
},
{
"path": "scripts/download_first_stages.sh",
"chars": 1324,
"preview": "#!/bin/bash\nwget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip\nwge"
},
{
"path": "scripts/download_models.sh",
"chars": 1681,
"preview": "#!/bin/bash\nwget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip\nwget -O "
},
{
"path": "scripts/sample_diffusion.py",
"chars": 9606,
"preview": "import argparse, os, sys, glob, datetime, yaml\nimport torch\nimport time\nimport numpy as np\nfrom tqdm import trange\n\nfrom"
},
{
"path": "scripts/score.py",
"chars": 3219,
"preview": "import clip\nimport torch\nimport os\nfrom PIL import Image\nimport lpips\nfrom torchvision.transforms import Compose, Resize"
},
{
"path": "scripts/stable_txt2img_guidance.py",
"chars": 11432,
"preview": "import argparse, os, sys, glob\nimport torch\nimport numpy as np\nfrom omegaconf import OmegaConf\nfrom PIL import Image\nfro"
},
{
"path": "scripts/stable_txt2img_multi_guidance.py",
"chars": 12321,
"preview": "import argparse, os, sys, glob\nimport torch\nimport numpy as np\nfrom omegaconf import OmegaConf\nfrom PIL import Image\nfro"
},
{
"path": "setup.py",
"chars": 233,
"preview": "from setuptools import setup, find_packages\n\nsetup(\n name='latent-diffusion',\n version='0.0.1',\n description=''"
}
]
About this extraction
This page contains the full source code of the zhang-zx/SINE GitHub repository, extracted and formatted as plain text: 51 files (480.7 KB, approximately 121.2k tokens) plus a symbol index of 578 extracted functions, classes, methods, constants, and types. Extracted by GitExtract.