Showing preview only (4,743K chars total). Download the full file or copy to clipboard to get everything.
Repository: XavierXiao/Dreambooth-Stable-Diffusion
Branch: main
Commit: 12fa9509f644
Files: 86
Total size: 4.5 MB
Directory structure:
gitextract_pv77kbwr/
├── LICENSE
├── README.md
├── configs/
│ ├── autoencoder/
│ │ ├── autoencoder_kl_16x16x16.yaml
│ │ ├── autoencoder_kl_32x32x4.yaml
│ │ ├── autoencoder_kl_64x64x3.yaml
│ │ └── autoencoder_kl_8x8x64.yaml
│ ├── latent-diffusion/
│ │ ├── celebahq-ldm-vq-4.yaml
│ │ ├── cin-ldm-vq-f8.yaml
│ │ ├── cin256-v2.yaml
│ │ ├── ffhq-ldm-vq-4.yaml
│ │ ├── lsun_bedrooms-ldm-vq-4.yaml
│ │ ├── lsun_churches-ldm-kl-8.yaml
│ │ ├── txt2img-1p4B-eval.yaml
│ │ ├── txt2img-1p4B-eval_with_tokens.yaml
│ │ ├── txt2img-1p4B-finetune.yaml
│ │ └── txt2img-1p4B-finetune_style.yaml
│ └── stable-diffusion/
│ ├── v1-finetune.yaml
│ ├── v1-finetune_unfrozen.yaml
│ └── v1-inference.yaml
├── environment.yaml
├── evaluation/
│ └── clip_eval.py
├── ldm/
│ ├── data/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── imagenet.py
│ │ ├── lsun.py
│ │ ├── personalized.py
│ │ └── personalized_style.py
│ ├── lr_scheduler.py
│ ├── models/
│ │ ├── autoencoder.py
│ │ └── diffusion/
│ │ ├── __init__.py
│ │ ├── classifier.py
│ │ ├── ddim.py
│ │ ├── ddpm.py
│ │ └── plms.py
│ ├── modules/
│ │ ├── attention.py
│ │ ├── diffusionmodules/
│ │ │ ├── __init__.py
│ │ │ ├── model.py
│ │ │ ├── openaimodel.py
│ │ │ └── util.py
│ │ ├── distributions/
│ │ │ ├── __init__.py
│ │ │ └── distributions.py
│ │ ├── ema.py
│ │ ├── embedding_manager.py
│ │ ├── encoders/
│ │ │ ├── __init__.py
│ │ │ ├── modules.py
│ │ │ └── modules_bak.py
│ │ ├── image_degradation/
│ │ │ ├── __init__.py
│ │ │ ├── bsrgan.py
│ │ │ ├── bsrgan_light.py
│ │ │ └── utils_image.py
│ │ ├── losses/
│ │ │ ├── __init__.py
│ │ │ ├── contperceptual.py
│ │ │ └── vqperceptual.py
│ │ └── x_transformer.py
│ └── util.py
├── main.py
├── merge_embeddings.py
├── models/
│ ├── first_stage_models/
│ │ ├── kl-f16/
│ │ │ └── config.yaml
│ │ ├── kl-f32/
│ │ │ └── config.yaml
│ │ ├── kl-f4/
│ │ │ └── config.yaml
│ │ ├── kl-f8/
│ │ │ └── config.yaml
│ │ ├── vq-f16/
│ │ │ └── config.yaml
│ │ ├── vq-f4/
│ │ │ └── config.yaml
│ │ ├── vq-f4-noattn/
│ │ │ └── config.yaml
│ │ ├── vq-f8/
│ │ │ └── config.yaml
│ │ └── vq-f8-n256/
│ │ └── config.yaml
│ └── ldm/
│ ├── bsr_sr/
│ │ └── config.yaml
│ ├── celeba256/
│ │ └── config.yaml
│ ├── cin256/
│ │ └── config.yaml
│ ├── ffhq256/
│ │ └── config.yaml
│ ├── inpainting_big/
│ │ └── config.yaml
│ ├── layout2img-openimages256/
│ │ └── config.yaml
│ ├── lsun_beds256/
│ │ └── config.yaml
│ ├── lsun_churches256/
│ │ └── config.yaml
│ ├── semantic_synthesis256/
│ │ └── config.yaml
│ ├── semantic_synthesis512/
│ │ └── config.yaml
│ └── text2img256/
│ └── config.yaml
├── scripts/
│ ├── download_first_stages.sh
│ ├── download_models.sh
│ ├── evaluate_model.py
│ ├── inpaint.py
│ ├── latent_imagenet_diffusion.ipynb
│ ├── sample_diffusion.py
│ ├── stable_txt2img.py
│ └── txt2img.py
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2022 Rinon Gal, Yuval Alaluf, Yuval Atzmon, Or Patashnik and contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Dreambooth on Stable Diffusion
This is an implementtaion of Google's [Dreambooth](https://arxiv.org/abs/2208.12242) with [Stable Diffusion](https://github.com/CompVis/stable-diffusion). The original Dreambooth is based on [Imagen](https://imagen.research.google/) text-to-image model. However, neither the model nor the pre-trained weights of Imagen is available. To enable people to fine-tune a text-to-image model with a few examples, I implemented the idea of Dreambooth on Stable diffusion.
This code repository is based on that of [Textual Inversion](https://github.com/rinongal/textual_inversion). Note that Textual Inversion only optimizes word ebedding, while dreambooth fine-tunes the whole diffusion model.
The implementation makes minimum changes over the official codebase of Textual Inversion. In fact, due to lazyness, some components in Textual Inversion, such as the embedding manager, are not deleted, although they will never be used here.
## Update
**9/20/2022**: I just found a way to reduce the GPU memory a bit. Remember that this code is based on Textual Inversion, and TI's code base has [this line](https://github.com/rinongal/textual_inversion/blob/main/ldm/modules/diffusionmodules/util.py#L112), which disable gradient checkpointing in a hard-code way. This is because in TI, the Unet is not optimized. However, in Dreambooth we optimize the Unet, so we can turn on the gradient checkpoint pointing trick, as in the original SD repo [here](https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L112). The gradient checkpoint is default to be True in [config](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/blob/main/configs/stable-diffusion/v1-finetune_unfrozen.yaml#L47). I have updated the codes.
## Usage
### Preparation
First set-up the ```ldm``` enviroment following the instruction from textual inversion repo, or the original Stable Diffusion repo.
To fine-tune a stable diffusion model, you need to obtain the pre-trained stable diffusion models following their [instructions](https://github.com/CompVis/stable-diffusion#stable-diffusion-v1). Weights can be downloaded on [HuggingFace](https://huggingface.co/CompVis). You can decide which version of checkpoint to use, but I use ```sd-v1-4-full-ema.ckpt```.
We also need to create a set of images for regularization, as the fine-tuning algorithm of Dreambooth requires that. Details of the algorithm can be found in the paper. Note that in the original paper, the regularization images seem to be generated on-the-fly. However, here I generated a set of regularization images before the training. The text prompt for generating regularization images can be ```photo of a <class>```, where ```<class>``` is a word that describes the class of your object, such as ```dog```. The command is
```
python scripts/stable_txt2img.py --ddim_eta 0.0 --n_samples 8 --n_iter 1 --scale 10.0 --ddim_steps 50 --ckpt /path/to/original/stable-diffusion/sd-v1-4-full-ema.ckpt --prompt "a photo of a <class>"
```
I generate 8 images for regularization, but more regularization images may lead to stronger regularization and better editability. After that, save the generated images (separately, one image per ```.png``` file) at ```/root/to/regularization/images```.
**Updates on 9/9**
We should definitely use more images for regularization. Please try 100 or 200, to better align with the original paper. To acomodate this, I shorten the "repeat" of reg dataset in the [config file](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/blob/main/configs/stable-diffusion/v1-finetune_unfrozen.yaml#L96).
For some cases, if the generated regularization images are highly unrealistic (happens when you want to generate "man" or "woman"), you can find a diverse set of images (of man/woman) online, and use them as regularization images.
### Training
Training can be done by running the following command
```
python main.py --base configs/stable-diffusion/v1-finetune_unfrozen.yaml
-t
--actual_resume /path/to/original/stable-diffusion/sd-v1-4-full-ema.ckpt
-n <job name>
--gpus 0,
--data_root /root/to/training/images
--reg_data_root /root/to/regularization/images
--class_word <xxx>
```
Detailed configuration can be found in ```configs/stable-diffusion/v1-finetune_unfrozen.yaml```. In particular, the default learning rate is ```1.0e-6``` as I found the ```1.0e-5``` in the Dreambooth paper leads to poor editability. The parameter ```reg_weight``` corresponds to the weight of regularization in the Dreambooth paper, and the default is set to ```1.0```.
Dreambooth requires a placeholder word ```[V]```, called identifier, as in the paper. This identifier needs to be a relatively rare tokens in the vocabulary. The original paper approaches this by using a rare word in T5-XXL tokenizer. For simplicity, here I just use a random word ```sks``` and hard coded it.. If you want to change that, simply make a change in [this file](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/blob/main/ldm/data/personalized.py#L10).
Training will be run for 800 steps, and two checkpoints will be saved at ```./logs/<job_name>/checkpoints```, one at 500 steps and one at final step. Typically the one at 500 steps works well enough. I train the model use two A6000 GPUs and it takes ~15 mins.
### Generation
After training, personalized samples can be obtained by running the command
```
python scripts/stable_txt2img.py --ddim_eta 0.0
--n_samples 8
--n_iter 1
--scale 10.0
--ddim_steps 100
--ckpt /path/to/saved/checkpoint/from/training
--prompt "photo of a sks <class>"
```
In particular, ```sks``` is the identifier, which should be replaced by your choice if you happen to change the identifier, and ```<class>``` is the class word ```--class_word``` for training.
## Results
Here I show some qualitative results. The training images are obtained from the [issue](https://github.com/rinongal/textual_inversion/issues/8) in the Textual Inversion repository, and they are 3 images of a large trash container. Regularization images are generated by prompt ```photo of a container```. Regularization images are shown here:

After training, generated images with prompt ```photo of a sks container```:

Generated images with prompt ```photo of a sks container on the beach```:

Generated images with prompt ```photo of a sks container on the moon```:

Some not-so-perfect but still interesting results:
Generated images with prompt ```photo of a red sks container```:

Generated images with prompt ```a dog on top of sks container```:

================================================
FILE: configs/autoencoder/autoencoder_kl_16x16x16.yaml
================================================
model:
base_learning_rate: 4.5e-6
target: ldm.models.autoencoder.AutoencoderKL
params:
monitor: "val/rec_loss"
embed_dim: 16
lossconfig:
target: ldm.modules.losses.LPIPSWithDiscriminator
params:
disc_start: 50001
kl_weight: 0.000001
disc_weight: 0.5
ddconfig:
double_z: True
z_channels: 16
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
num_res_blocks: 2
attn_resolutions: [16]
dropout: 0.0
data:
target: main.DataModuleFromConfig
params:
batch_size: 12
wrap: True
train:
target: ldm.data.imagenet.ImageNetSRTrain
params:
size: 256
degradation: pil_nearest
validation:
target: ldm.data.imagenet.ImageNetSRValidation
params:
size: 256
degradation: pil_nearest
lightning:
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 1000
max_images: 8
increase_log_steps: True
trainer:
benchmark: True
accumulate_grad_batches: 2
================================================
FILE: configs/autoencoder/autoencoder_kl_32x32x4.yaml
================================================
model:
base_learning_rate: 4.5e-6
target: ldm.models.autoencoder.AutoencoderKL
params:
monitor: "val/rec_loss"
embed_dim: 4
lossconfig:
target: ldm.modules.losses.LPIPSWithDiscriminator
params:
disc_start: 50001
kl_weight: 0.000001
disc_weight: 0.5
ddconfig:
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
data:
target: main.DataModuleFromConfig
params:
batch_size: 12
wrap: True
train:
target: ldm.data.imagenet.ImageNetSRTrain
params:
size: 256
degradation: pil_nearest
validation:
target: ldm.data.imagenet.ImageNetSRValidation
params:
size: 256
degradation: pil_nearest
lightning:
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 1000
max_images: 8
increase_log_steps: True
trainer:
benchmark: True
accumulate_grad_batches: 2
================================================
FILE: configs/autoencoder/autoencoder_kl_64x64x3.yaml
================================================
model:
base_learning_rate: 4.5e-6
target: ldm.models.autoencoder.AutoencoderKL
params:
monitor: "val/rec_loss"
embed_dim: 3
lossconfig:
target: ldm.modules.losses.LPIPSWithDiscriminator
params:
disc_start: 50001
kl_weight: 0.000001
disc_weight: 0.5
ddconfig:
double_z: True
z_channels: 3
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
data:
target: main.DataModuleFromConfig
params:
batch_size: 12
wrap: True
train:
target: ldm.data.imagenet.ImageNetSRTrain
params:
size: 256
degradation: pil_nearest
validation:
target: ldm.data.imagenet.ImageNetSRValidation
params:
size: 256
degradation: pil_nearest
lightning:
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 1000
max_images: 8
increase_log_steps: True
trainer:
benchmark: True
accumulate_grad_batches: 2
================================================
FILE: configs/autoencoder/autoencoder_kl_8x8x64.yaml
================================================
model:
base_learning_rate: 4.5e-6
target: ldm.models.autoencoder.AutoencoderKL
params:
monitor: "val/rec_loss"
embed_dim: 64
lossconfig:
target: ldm.modules.losses.LPIPSWithDiscriminator
params:
disc_start: 50001
kl_weight: 0.000001
disc_weight: 0.5
ddconfig:
double_z: True
z_channels: 64
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
num_res_blocks: 2
attn_resolutions: [16,8]
dropout: 0.0
data:
target: main.DataModuleFromConfig
params:
batch_size: 12
wrap: True
train:
target: ldm.data.imagenet.ImageNetSRTrain
params:
size: 256
degradation: pil_nearest
validation:
target: ldm.data.imagenet.ImageNetSRValidation
params:
size: 256
degradation: pil_nearest
lightning:
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 1000
max_images: 8
increase_log_steps: True
trainer:
benchmark: True
accumulate_grad_batches: 2
================================================
FILE: configs/latent-diffusion/celebahq-ldm-vq-4.yaml
================================================
model:
base_learning_rate: 2.0e-06
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.0015
linear_end: 0.0195
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
image_size: 64
channels: 3
monitor: val/loss_simple_ema
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 64
in_channels: 3
out_channels: 3
model_channels: 224
attention_resolutions:
# note: this isn\t actually the resolution but
# the downsampling factor, i.e. this corresnponds to
# attention on spatial resolution 8,16,32, as the
# spatial reolution of the latents is 64 for f4
- 8
- 4
- 2
num_res_blocks: 2
channel_mult:
- 1
- 2
- 3
- 4
num_head_channels: 32
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
params:
embed_dim: 3
n_embed: 8192
ckpt_path: models/first_stage_models/vq-f4/model.ckpt
ddconfig:
double_z: false
z_channels: 3
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config: __is_unconditional__
data:
target: main.DataModuleFromConfig
params:
batch_size: 48
num_workers: 5
wrap: false
train:
target: taming.data.faceshq.CelebAHQTrain
params:
size: 256
validation:
target: taming.data.faceshq.CelebAHQValidation
params:
size: 256
lightning:
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 5000
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
================================================
FILE: configs/latent-diffusion/cin-ldm-vq-f8.yaml
================================================
model:
base_learning_rate: 1.0e-06
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.0015
linear_end: 0.0195
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: class_label
image_size: 32
channels: 4
cond_stage_trainable: true
conditioning_key: crossattn
monitor: val/loss_simple_ema
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32
in_channels: 4
out_channels: 4
model_channels: 256
attention_resolutions:
#note: this isn\t actually the resolution but
# the downsampling factor, i.e. this corresnponds to
# attention on spatial resolution 8,16,32, as the
# spatial reolution of the latents is 32 for f8
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
- 2
- 4
num_head_channels: 32
use_spatial_transformer: true
transformer_depth: 1
context_dim: 512
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
params:
embed_dim: 4
n_embed: 16384
ckpt_path: configs/first_stage_models/vq-f8/model.yaml
ddconfig:
double_z: false
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 2
- 4
num_res_blocks: 2
attn_resolutions:
- 32
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.ClassEmbedder
params:
embed_dim: 512
key: class_label
data:
target: main.DataModuleFromConfig
params:
batch_size: 64
num_workers: 12
wrap: false
train:
target: ldm.data.imagenet.ImageNetTrain
params:
config:
size: 256
validation:
target: ldm.data.imagenet.ImageNetValidation
params:
config:
size: 256
lightning:
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 5000
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
================================================
FILE: configs/latent-diffusion/cin256-v2.yaml
================================================
model:
base_learning_rate: 0.0001
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.0015
linear_end: 0.0195
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: class_label
image_size: 64
channels: 3
cond_stage_trainable: true
conditioning_key: crossattn
monitor: val/loss
use_ema: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 64
in_channels: 3
out_channels: 3
model_channels: 192
attention_resolutions:
- 8
- 4
- 2
num_res_blocks: 2
channel_mult:
- 1
- 2
- 3
- 5
num_heads: 1
use_spatial_transformer: true
transformer_depth: 1
context_dim: 512
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
params:
embed_dim: 3
n_embed: 8192
ddconfig:
double_z: false
z_channels: 3
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.ClassEmbedder
params:
n_classes: 1001
embed_dim: 512
key: class_label
================================================
FILE: configs/latent-diffusion/ffhq-ldm-vq-4.yaml
================================================
model:
base_learning_rate: 2.0e-06
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.0015
linear_end: 0.0195
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
image_size: 64
channels: 3
monitor: val/loss_simple_ema
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 64
in_channels: 3
out_channels: 3
model_channels: 224
attention_resolutions:
# note: this isn\t actually the resolution but
# the downsampling factor, i.e. this corresnponds to
# attention on spatial resolution 8,16,32, as the
# spatial reolution of the latents is 64 for f4
- 8
- 4
- 2
num_res_blocks: 2
channel_mult:
- 1
- 2
- 3
- 4
num_head_channels: 32
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
params:
embed_dim: 3
n_embed: 8192
ckpt_path: configs/first_stage_models/vq-f4/model.yaml
ddconfig:
double_z: false
z_channels: 3
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config: __is_unconditional__
data:
target: main.DataModuleFromConfig
params:
batch_size: 42
num_workers: 5
wrap: false
train:
target: taming.data.faceshq.FFHQTrain
params:
size: 256
validation:
target: taming.data.faceshq.FFHQValidation
params:
size: 256
lightning:
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 5000
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
================================================
FILE: configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml
================================================
model:
base_learning_rate: 2.0e-06
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.0015
linear_end: 0.0195
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
image_size: 64
channels: 3
monitor: val/loss_simple_ema
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 64
in_channels: 3
out_channels: 3
model_channels: 224
attention_resolutions:
# note: this isn\t actually the resolution but
# the downsampling factor, i.e. this corresnponds to
# attention on spatial resolution 8,16,32, as the
# spatial reolution of the latents is 64 for f4
- 8
- 4
- 2
num_res_blocks: 2
channel_mult:
- 1
- 2
- 3
- 4
num_head_channels: 32
first_stage_config:
target: ldm.models.autoencoder.VQModelInterface
params:
ckpt_path: configs/first_stage_models/vq-f4/model.yaml
embed_dim: 3
n_embed: 8192
ddconfig:
double_z: false
z_channels: 3
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config: __is_unconditional__
data:
target: main.DataModuleFromConfig
params:
batch_size: 48
num_workers: 5
wrap: false
train:
target: ldm.data.lsun.LSUNBedroomsTrain
params:
size: 256
validation:
target: ldm.data.lsun.LSUNBedroomsValidation
params:
size: 256
lightning:
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 5000
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
================================================
FILE: configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml
================================================
model:
base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.0015
linear_end: 0.0155
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
loss_type: l1
first_stage_key: "image"
cond_stage_key: "image"
image_size: 32
channels: 4
cond_stage_trainable: False
concat_mode: False
scale_by_std: True
monitor: 'val/loss_simple_ema'
scheduler_config: # 10000 warmup steps
target: ldm.lr_scheduler.LambdaLinearScheduler
params:
warm_up_steps: [10000]
cycle_lengths: [10000000000000]
f_start: [1.e-6]
f_max: [1.]
f_min: [ 1.]
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32
in_channels: 4
out_channels: 4
model_channels: 192
attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
num_res_blocks: 2
channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
num_heads: 8
use_scale_shift_norm: True
resblock_updown: True
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: "val/rec_loss"
ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
ddconfig:
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config: "__is_unconditional__"
data:
target: main.DataModuleFromConfig
params:
batch_size: 96
num_workers: 5
wrap: False
train:
target: ldm.data.lsun.LSUNChurchesTrain
params:
size: 256
validation:
target: ldm.data.lsun.LSUNChurchesValidation
params:
size: 256
lightning:
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 5000
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
================================================
FILE: configs/latent-diffusion/txt2img-1p4B-eval.yaml
================================================
model:
base_learning_rate: 5.0e-05
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.012
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 32
channels: 4
cond_stage_trainable: true
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions:
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
- 2
- 4
- 4
num_heads: 8
use_spatial_transformer: true
transformer_depth: 1
context_dim: 1280
use_checkpoint: true
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.BERTEmbedder
params:
n_embed: 1280
n_layer: 32
================================================
FILE: configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml
================================================
model:
base_learning_rate: 5.0e-05
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.012
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 32
channels: 4
cond_stage_trainable: true
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: []
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions:
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
- 2
- 4
- 4
num_heads: 8
use_spatial_transformer: true
transformer_depth: 1
context_dim: 1280
use_checkpoint: true
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.BERTEmbedder
params:
n_embed: 1280
n_layer: 32
================================================
FILE: configs/latent-diffusion/txt2img-1p4B-finetune.yaml
================================================
model:
base_learning_rate: 5.0e-3
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.012
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 32
channels: 4
cond_stage_trainable: true
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
embedding_reg_weight: 0.0
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions:
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
- 2
- 4
- 4
num_heads: 8
use_spatial_transformer: true
transformer_depth: 1
context_dim: 1280
use_checkpoint: true
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.BERTEmbedder
params:
n_embed: 1280
n_layer: 32
data:
target: main.DataModuleFromConfig
params:
batch_size: 4
num_workers: 2
wrap: false
train:
target: ldm.data.personalized.PersonalizedBase
params:
size: 256
set: train
per_image_tokens: false
repeats: 100
validation:
target: ldm.data.personalized.PersonalizedBase
params:
size: 256
set: val
per_image_tokens: false
repeats: 10
lightning:
modelcheckpoint:
params:
every_n_train_steps: 500
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 500
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
max_steps: 6100
================================================
FILE: configs/latent-diffusion/txt2img-1p4B-finetune_style.yaml
================================================
model:
base_learning_rate: 5.0e-3
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.012
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 32
channels: 4
cond_stage_trainable: true
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
embedding_reg_weight: 0.0
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["painting"]
per_image_tokens: false
num_vectors_per_token: 1
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions:
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
- 2
- 4
- 4
num_heads: 8
use_spatial_transformer: true
transformer_depth: 1
context_dim: 1280
use_checkpoint: true
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.BERTEmbedder
params:
n_embed: 1280
n_layer: 32
data:
target: main.DataModuleFromConfig
params:
batch_size: 4
num_workers: 4
wrap: false
train:
target: ldm.data.personalized_style.PersonalizedBase
params:
size: 256
set: train
per_image_tokens: false
repeats: 100
validation:
target: ldm.data.personalized_style.PersonalizedBase
params:
size: 256
set: val
per_image_tokens: false
repeats: 10
lightning:
modelcheckpoint:
params:
every_n_train_steps: 500
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 500
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
================================================
FILE: configs/stable-diffusion/v1-finetune.yaml
================================================
model:
base_learning_rate: 5.0e-03
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 64
channels: 4
cond_stage_trainable: true # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
embedding_reg_weight: 0.0
unfreeze_model: False
model_lr: 0.0
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 512
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: main.DataModuleFromConfig
params:
batch_size: 2
num_workers: 2
wrap: false
train:
target: ldm.data.personalized.PersonalizedBase
params:
size: 512
set: train
per_image_tokens: false
repeats: 100
validation:
target: ldm.data.personalized.PersonalizedBase
params:
size: 512
set: val
per_image_tokens: false
repeats: 10
lightning:
modelcheckpoint:
params:
every_n_train_steps: 500
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 500
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
max_steps: 6100
================================================
FILE: configs/stable-diffusion/v1-finetune_unfrozen.yaml
================================================
model:
base_learning_rate: 1.0e-06
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
reg_weight: 1.0
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: image
cond_stage_key: caption
image_size: 64
channels: 4
cond_stage_trainable: true # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
embedding_reg_weight: 0.0
unfreeze_model: True
model_lr: 1.0e-6
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 512
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
data:
target: main.DataModuleFromConfig
params:
batch_size: 1
num_workers: 2
wrap: false
train:
target: ldm.data.personalized.PersonalizedBase
params:
size: 512
set: train
per_image_tokens: false
repeats: 100
reg:
target: ldm.data.personalized.PersonalizedBase
params:
size: 512
set: train
reg: true
per_image_tokens: false
repeats: 10
validation:
target: ldm.data.personalized.PersonalizedBase
params:
size: 512
set: val
per_image_tokens: false
repeats: 10
lightning:
modelcheckpoint:
params:
every_n_train_steps: 500
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 500
max_images: 8
increase_log_steps: False
trainer:
benchmark: True
max_steps: 800
================================================
FILE: configs/stable-diffusion/v1-inference.yaml
================================================
model:
base_learning_rate: 1.0e-04
target: ldm.models.diffusion.ddpm.LatentDiffusion
params:
linear_start: 0.00085
linear_end: 0.0120
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: "jpg"
cond_stage_key: "txt"
image_size: 64
channels: 4
cond_stage_trainable: false # Note: different from the one we trained before
conditioning_key: crossattn
monitor: val/loss_simple_ema
scale_factor: 0.18215
use_ema: False
personalization_config:
target: ldm.modules.embedding_manager.EmbeddingManager
params:
placeholder_strings: ["*"]
initializer_words: ["sculpture"]
per_image_tokens: false
num_vectors_per_token: 1
progressive_words: False
unet_config:
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
params:
image_size: 32 # unused
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2, 1 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4, 4 ]
num_heads: 8
use_spatial_transformer: True
transformer_depth: 1
context_dim: 768
use_checkpoint: True
legacy: False
first_stage_config:
target: ldm.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
================================================
FILE: environment.yaml
================================================
name: ldm
channels:
- pytorch
- defaults
dependencies:
- python=3.8.10
- pip=20.3
- cudatoolkit=11.3
- pytorch=1.10.2
- torchvision=0.11.3
- numpy=1.22.3
- pip:
- albumentations==1.1.0
- opencv-python==4.2.0.34
- pudb==2019.2
- imageio==2.14.1
- imageio-ffmpeg==0.4.7
- pytorch-lightning==1.5.9
- omegaconf==2.1.1
- test-tube>=0.7.5
- streamlit>=0.73.1
- setuptools==59.5.0
- pillow==9.0.1
- einops==0.4.1
- torch-fidelity==0.3.0
- transformers==4.18.0
- torchmetrics==0.6.0
- kornia==0.6
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
- -e .
================================================
FILE: evaluation/clip_eval.py
================================================
import clip
import torch
from torchvision import transforms
from ldm.models.diffusion.ddim import DDIMSampler
class CLIPEvaluator(object):
def __init__(self, device, clip_model='ViT-B/32') -> None:
self.device = device
self.model, clip_preprocess = clip.load(clip_model, device=self.device)
self.clip_preprocess = clip_preprocess
self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (generator output) to [0, 1].
clip_preprocess.transforms[:2] + # to match CLIP input scale assumptions
clip_preprocess.transforms[4:]) # + skip convert PIL to tensor
def tokenize(self, strings: list):
return clip.tokenize(strings).to(self.device)
@torch.no_grad()
def encode_text(self, tokens: list) -> torch.Tensor:
return self.model.encode_text(tokens)
@torch.no_grad()
def encode_images(self, images: torch.Tensor) -> torch.Tensor:
images = self.preprocess(images).to(self.device)
return self.model.encode_image(images)
def get_text_features(self, text: str, norm: bool = True) -> torch.Tensor:
tokens = clip.tokenize(text).to(self.device)
text_features = self.encode_text(tokens).detach()
if norm:
text_features /= text_features.norm(dim=-1, keepdim=True)
return text_features
def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor:
image_features = self.encode_images(img)
if norm:
image_features /= image_features.clone().norm(dim=-1, keepdim=True)
return image_features
def img_to_img_similarity(self, src_images, generated_images):
src_img_features = self.get_image_features(src_images)
gen_img_features = self.get_image_features(generated_images)
return (src_img_features @ gen_img_features.T).mean()
def txt_to_img_similarity(self, text, generated_images):
text_features = self.get_text_features(text)
gen_img_features = self.get_image_features(generated_images)
return (text_features @ gen_img_features.T).mean()
class LDMCLIPEvaluator(CLIPEvaluator):
def __init__(self, device, clip_model='ViT-B/32') -> None:
super().__init__(device, clip_model)
def evaluate(self, ldm_model, src_images, target_text, n_samples=64, n_steps=50):
sampler = DDIMSampler(ldm_model)
samples_per_batch = 8
n_batches = n_samples // samples_per_batch
# generate samples
all_samples=list()
with torch.no_grad():
with ldm_model.ema_scope():
uc = ldm_model.get_learned_conditioning(samples_per_batch * [""])
for batch in range(n_batches):
c = ldm_model.get_learned_conditioning(samples_per_batch * [target_text])
shape = [4, 256//8, 256//8]
samples_ddim, _ = sampler.sample(S=n_steps,
conditioning=c,
batch_size=samples_per_batch,
shape=shape,
verbose=False,
unconditional_guidance_scale=5.0,
unconditional_conditioning=uc,
eta=0.0)
x_samples_ddim = ldm_model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp(x_samples_ddim, min=-1.0, max=1.0)
all_samples.append(x_samples_ddim)
all_samples = torch.cat(all_samples, axis=0)
sim_samples_to_img = self.img_to_img_similarity(src_images, all_samples)
sim_samples_to_text = self.txt_to_img_similarity(target_text.replace("*", ""), all_samples)
return sim_samples_to_img, sim_samples_to_text
class ImageDirEvaluator(CLIPEvaluator):
def __init__(self, device, clip_model='ViT-B/32') -> None:
super().__init__(device, clip_model)
def evaluate(self, gen_samples, src_images, target_text):
sim_samples_to_img = self.img_to_img_similarity(src_images, gen_samples)
sim_samples_to_text = self.txt_to_img_similarity(target_text.replace("*", ""), gen_samples)
return sim_samples_to_img, sim_samples_to_text
================================================
FILE: ldm/data/__init__.py
================================================
================================================
FILE: ldm/data/base.py
================================================
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
class Txt2ImgIterableBaseDataset(IterableDataset):
'''
Define an interface to make the IterableDatasets for text2img data chainable
'''
def __init__(self, num_records=0, valid_ids=None, size=256):
super().__init__()
self.num_records = num_records
self.valid_ids = valid_ids
self.sample_ids = valid_ids
self.size = size
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
def __len__(self):
return self.num_records
@abstractmethod
def __iter__(self):
pass
================================================
FILE: ldm/data/imagenet.py
================================================
import os, yaml, pickle, shutil, tarfile, glob
import cv2
import albumentations
import PIL
import numpy as np
import torchvision.transforms.functional as TF
from omegaconf import OmegaConf
from functools import partial
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, Subset
import taming.data.utils as tdu
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
from taming.data.imagenet import ImagePaths
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
def synset2idx(path_to_yaml="data/index_synset.yaml"):
with open(path_to_yaml) as f:
di2s = yaml.load(f)
return dict((v,k) for k,v in di2s.items())
class ImageNetBase(Dataset):
def __init__(self, config=None):
self.config = config or OmegaConf.create()
if not type(self.config)==dict:
self.config = OmegaConf.to_container(self.config)
self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
self._prepare()
self._prepare_synset_to_human()
self._prepare_idx_to_synset()
self._prepare_human_to_integer_label()
self._load()
def __len__(self):
return len(self.data)
def __getitem__(self, i):
return self.data[i]
def _prepare(self):
raise NotImplementedError()
def _filter_relpaths(self, relpaths):
ignore = set([
"n06596364_9591.JPEG",
])
relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
if "sub_indices" in self.config:
indices = str_to_indices(self.config["sub_indices"])
synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
files = []
for rpath in relpaths:
syn = rpath.split("/")[0]
if syn in synsets:
files.append(rpath)
return files
else:
return relpaths
def _prepare_synset_to_human(self):
SIZE = 2655750
URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
self.human_dict = os.path.join(self.root, "synset_human.txt")
if (not os.path.exists(self.human_dict) or
not os.path.getsize(self.human_dict)==SIZE):
download(URL, self.human_dict)
def _prepare_idx_to_synset(self):
URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
self.idx2syn = os.path.join(self.root, "index_synset.yaml")
if (not os.path.exists(self.idx2syn)):
download(URL, self.idx2syn)
def _prepare_human_to_integer_label(self):
URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
if (not os.path.exists(self.human2integer)):
download(URL, self.human2integer)
with open(self.human2integer, "r") as f:
lines = f.read().splitlines()
assert len(lines) == 1000
self.human2integer_dict = dict()
for line in lines:
value, key = line.split(":")
self.human2integer_dict[key] = int(value)
def _load(self):
with open(self.txt_filelist, "r") as f:
self.relpaths = f.read().splitlines()
l1 = len(self.relpaths)
self.relpaths = self._filter_relpaths(self.relpaths)
print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
self.synsets = [p.split("/")[0] for p in self.relpaths]
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
unique_synsets = np.unique(self.synsets)
class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
if not self.keep_orig_class_label:
self.class_labels = [class_dict[s] for s in self.synsets]
else:
self.class_labels = [self.synset2idx[s] for s in self.synsets]
with open(self.human_dict, "r") as f:
human_dict = f.read().splitlines()
human_dict = dict(line.split(maxsplit=1) for line in human_dict)
self.human_labels = [human_dict[s] for s in self.synsets]
labels = {
"relpath": np.array(self.relpaths),
"synsets": np.array(self.synsets),
"class_label": np.array(self.class_labels),
"human_label": np.array(self.human_labels),
}
if self.process_images:
self.size = retrieve(self.config, "size", default=256)
self.data = ImagePaths(self.abspaths,
labels=labels,
size=self.size,
random_crop=self.random_crop,
)
else:
self.data = self.abspaths
class ImageNetTrain(ImageNetBase):
NAME = "ILSVRC2012_train"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
FILES = [
"ILSVRC2012_img_train.tar",
]
SIZES = [
147897477120,
]
def __init__(self, process_images=True, data_root=None, **kwargs):
self.process_images = process_images
self.data_root = data_root
super().__init__(**kwargs)
def _prepare(self):
if self.data_root:
self.root = os.path.join(self.data_root, self.NAME)
else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.expected_length = 1281167
self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
default=True)
if not tdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
print("Extracting {} to {}".format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar:
tar.extractall(path=datadir)
print("Extracting sub-tars.")
subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
for subpath in tqdm(subpaths):
subdir = subpath[:-len(".tar")]
os.makedirs(subdir, exist_ok=True)
with tarfile.open(subpath, "r:") as tar:
tar.extractall(path=subdir)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
with open(self.txt_filelist, "w") as f:
f.write(filelist)
tdu.mark_prepared(self.root)
class ImageNetValidation(ImageNetBase):
NAME = "ILSVRC2012_validation"
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
FILES = [
"ILSVRC2012_img_val.tar",
"validation_synset.txt",
]
SIZES = [
6744924160,
1950000,
]
def __init__(self, process_images=True, data_root=None, **kwargs):
self.data_root = data_root
self.process_images = process_images
super().__init__(**kwargs)
def _prepare(self):
if self.data_root:
self.root = os.path.join(self.data_root, self.NAME)
else:
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
self.datadir = os.path.join(self.root, "data")
self.txt_filelist = os.path.join(self.root, "filelist.txt")
self.expected_length = 50000
self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
default=False)
if not tdu.is_prepared(self.root):
# prep
print("Preparing dataset {} in {}".format(self.NAME, self.root))
datadir = self.datadir
if not os.path.exists(datadir):
path = os.path.join(self.root, self.FILES[0])
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
import academictorrents as at
atpath = at.get(self.AT_HASH, datastore=self.root)
assert atpath == path
print("Extracting {} to {}".format(path, datadir))
os.makedirs(datadir, exist_ok=True)
with tarfile.open(path, "r:") as tar:
tar.extractall(path=datadir)
vspath = os.path.join(self.root, self.FILES[1])
if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
download(self.VS_URL, vspath)
with open(vspath, "r") as f:
synset_dict = f.read().splitlines()
synset_dict = dict(line.split() for line in synset_dict)
print("Reorganizing into synset folders")
synsets = np.unique(list(synset_dict.values()))
for s in synsets:
os.makedirs(os.path.join(datadir, s), exist_ok=True)
for k, v in synset_dict.items():
src = os.path.join(datadir, k)
dst = os.path.join(datadir, v)
shutil.move(src, dst)
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
filelist = sorted(filelist)
filelist = "\n".join(filelist)+"\n"
with open(self.txt_filelist, "w") as f:
f.write(filelist)
tdu.mark_prepared(self.root)
class ImageNetSR(Dataset):
def __init__(self, size=None,
degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
random_crop=True):
"""
Imagenet Superresolution Dataloader
Performs following ops in order:
1. crops a crop of size s from image either as random or center crop
2. resizes crop to size with cv2.area_interpolation
3. degrades resized crop with degradation_fn
:param size: resizing to size after cropping
:param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
:param downscale_f: Low Resolution Downsample factor
:param min_crop_f: determines crop size s,
where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
:param max_crop_f: ""
:param data_root:
:param random_crop:
"""
self.base = self.get_base()
assert size
assert (size / downscale_f).is_integer()
self.size = size
self.LR_size = int(size / downscale_f)
self.min_crop_f = min_crop_f
self.max_crop_f = max_crop_f
assert(max_crop_f <= 1.)
self.center_crop = not random_crop
self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
if degradation == "bsrgan":
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
elif degradation == "bsrgan_light":
self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
else:
interpolation_fn = {
"cv_nearest": cv2.INTER_NEAREST,
"cv_bilinear": cv2.INTER_LINEAR,
"cv_bicubic": cv2.INTER_CUBIC,
"cv_area": cv2.INTER_AREA,
"cv_lanczos": cv2.INTER_LANCZOS4,
"pil_nearest": PIL.Image.NEAREST,
"pil_bilinear": PIL.Image.BILINEAR,
"pil_bicubic": PIL.Image.BICUBIC,
"pil_box": PIL.Image.BOX,
"pil_hamming": PIL.Image.HAMMING,
"pil_lanczos": PIL.Image.LANCZOS,
}[degradation]
self.pil_interpolation = degradation.startswith("pil_")
if self.pil_interpolation:
self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
else:
self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
interpolation=interpolation_fn)
def __len__(self):
return len(self.base)
def __getitem__(self, i):
example = self.base[i]
image = Image.open(example["file_path_"])
if not image.mode == "RGB":
image = image.convert("RGB")
image = np.array(image).astype(np.uint8)
min_side_len = min(image.shape[:2])
crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
crop_side_len = int(crop_side_len)
if self.center_crop:
self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
else:
self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
image = self.cropper(image=image)["image"]
image = self.image_rescaler(image=image)["image"]
if self.pil_interpolation:
image_pil = PIL.Image.fromarray(image)
LR_image = self.degradation_process(image_pil)
LR_image = np.array(LR_image).astype(np.uint8)
else:
LR_image = self.degradation_process(image=image)["image"]
example["image"] = (image/127.5 - 1.0).astype(np.float32)
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
return example
class ImageNetSRTrain(ImageNetSR):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def get_base(self):
with open("data/imagenet_train_hr_indices.p", "rb") as f:
indices = pickle.load(f)
dset = ImageNetTrain(process_images=False,)
return Subset(dset, indices)
class ImageNetSRValidation(ImageNetSR):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def get_base(self):
with open("data/imagenet_val_hr_indices.p", "rb") as f:
indices = pickle.load(f)
dset = ImageNetValidation(process_images=False,)
return Subset(dset, indices)
================================================
FILE: ldm/data/lsun.py
================================================
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class LSUNBase(Dataset):
def __init__(self,
txt_file,
data_root,
size=None,
interpolation="bicubic",
flip_p=0.5
):
self.data_paths = txt_file
self.data_root = data_root
with open(self.data_paths, "r") as f:
self.image_paths = f.read().splitlines()
self._length = len(self.image_paths)
self.labels = {
"relative_file_path_": [l for l in self.image_paths],
"file_path_": [os.path.join(self.data_root, l)
for l in self.image_paths],
}
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
return self._length
def __getitem__(self, i):
example = dict((k, self.labels[k][i]) for k in self.labels)
image = Image.open(example["file_path_"])
if not image.mode == "RGB":
image = image.convert("RGB")
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
class LSUNChurchesTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
class LSUNChurchesValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
flip_p=flip_p, **kwargs)
class LSUNBedroomsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
class LSUNBedroomsValidation(LSUNBase):
def __init__(self, flip_p=0.0, **kwargs):
super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
flip_p=flip_p, **kwargs)
class LSUNCatsTrain(LSUNBase):
def __init__(self, **kwargs):
super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
class LSUNCatsValidation(LSUNBase):
def __init__(self, flip_p=0., **kwargs):
super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
flip_p=flip_p, **kwargs)
================================================
FILE: ldm/data/personalized.py
================================================
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import random
training_templates_smallest = [
'photo of a sks {}',
]
reg_templates_smallest = [
'photo of a {}',
]
imagenet_templates_small = [
'a photo of a {}',
'a rendering of a {}',
'a cropped photo of the {}',
'the photo of a {}',
'a photo of a clean {}',
'a photo of a dirty {}',
'a dark photo of the {}',
'a photo of my {}',
'a photo of the cool {}',
'a close-up photo of a {}',
'a bright photo of the {}',
'a cropped photo of a {}',
'a photo of the {}',
'a good photo of the {}',
'a photo of one {}',
'a close-up photo of the {}',
'a rendition of the {}',
'a photo of the clean {}',
'a rendition of a {}',
'a photo of a nice {}',
'a good photo of a {}',
'a photo of the nice {}',
'a photo of the small {}',
'a photo of the weird {}',
'a photo of the large {}',
'a photo of a cool {}',
'a photo of a small {}',
'an illustration of a {}',
'a rendering of a {}',
'a cropped photo of the {}',
'the photo of a {}',
'an illustration of a clean {}',
'an illustration of a dirty {}',
'a dark photo of the {}',
'an illustration of my {}',
'an illustration of the cool {}',
'a close-up photo of a {}',
'a bright photo of the {}',
'a cropped photo of a {}',
'an illustration of the {}',
'a good photo of the {}',
'an illustration of one {}',
'a close-up photo of the {}',
'a rendition of the {}',
'an illustration of the clean {}',
'a rendition of a {}',
'an illustration of a nice {}',
'a good photo of a {}',
'an illustration of the nice {}',
'an illustration of the small {}',
'an illustration of the weird {}',
'an illustration of the large {}',
'an illustration of a cool {}',
'an illustration of a small {}',
'a depiction of a {}',
'a rendering of a {}',
'a cropped photo of the {}',
'the photo of a {}',
'a depiction of a clean {}',
'a depiction of a dirty {}',
'a dark photo of the {}',
'a depiction of my {}',
'a depiction of the cool {}',
'a close-up photo of a {}',
'a bright photo of the {}',
'a cropped photo of a {}',
'a depiction of the {}',
'a good photo of the {}',
'a depiction of one {}',
'a close-up photo of the {}',
'a rendition of the {}',
'a depiction of the clean {}',
'a rendition of a {}',
'a depiction of a nice {}',
'a good photo of a {}',
'a depiction of the nice {}',
'a depiction of the small {}',
'a depiction of the weird {}',
'a depiction of the large {}',
'a depiction of a cool {}',
'a depiction of a small {}',
]
imagenet_dual_templates_small = [
'a photo of a {} with {}',
'a rendering of a {} with {}',
'a cropped photo of the {} with {}',
'the photo of a {} with {}',
'a photo of a clean {} with {}',
'a photo of a dirty {} with {}',
'a dark photo of the {} with {}',
'a photo of my {} with {}',
'a photo of the cool {} with {}',
'a close-up photo of a {} with {}',
'a bright photo of the {} with {}',
'a cropped photo of a {} with {}',
'a photo of the {} with {}',
'a good photo of the {} with {}',
'a photo of one {} with {}',
'a close-up photo of the {} with {}',
'a rendition of the {} with {}',
'a photo of the clean {} with {}',
'a rendition of a {} with {}',
'a photo of a nice {} with {}',
'a good photo of a {} with {}',
'a photo of the nice {} with {}',
'a photo of the small {} with {}',
'a photo of the weird {} with {}',
'a photo of the large {} with {}',
'a photo of a cool {} with {}',
'a photo of a small {} with {}',
]
per_img_token_list = [
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
]
class PersonalizedBase(Dataset):
def __init__(self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="dog",
per_image_tokens=False,
center_crop=False,
mixing_prob=0.25,
coarse_class_text=None,
reg = False
):
self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
self.placeholder_token = placeholder_token
self.per_image_tokens = per_image_tokens
self.center_crop = center_crop
self.mixing_prob = mixing_prob
self.coarse_class_text = coarse_class_text
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.reg = reg
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
placeholder_string = self.placeholder_token
if self.coarse_class_text:
placeholder_string = f"{self.coarse_class_text} {placeholder_string}"
if not self.reg:
text = random.choice(training_templates_smallest).format(placeholder_string)
else:
text = random.choice(reg_templates_smallest).format(placeholder_string)
example["caption"] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
================================================
FILE: ldm/data/personalized_style.py
================================================
import os
import numpy as np
import PIL
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import random
imagenet_templates_small = [
'a painting in the style of {}',
'a rendering in the style of {}',
'a cropped painting in the style of {}',
'the painting in the style of {}',
'a clean painting in the style of {}',
'a dirty painting in the style of {}',
'a dark painting in the style of {}',
'a picture in the style of {}',
'a cool painting in the style of {}',
'a close-up painting in the style of {}',
'a bright painting in the style of {}',
'a cropped painting in the style of {}',
'a good painting in the style of {}',
'a close-up painting in the style of {}',
'a rendition in the style of {}',
'a nice painting in the style of {}',
'a small painting in the style of {}',
'a weird painting in the style of {}',
'a large painting in the style of {}',
]
imagenet_dual_templates_small = [
'a painting in the style of {} with {}',
'a rendering in the style of {} with {}',
'a cropped painting in the style of {} with {}',
'the painting in the style of {} with {}',
'a clean painting in the style of {} with {}',
'a dirty painting in the style of {} with {}',
'a dark painting in the style of {} with {}',
'a cool painting in the style of {} with {}',
'a close-up painting in the style of {} with {}',
'a bright painting in the style of {} with {}',
'a cropped painting in the style of {} with {}',
'a good painting in the style of {} with {}',
'a painting of one {} in the style of {}',
'a nice painting in the style of {} with {}',
'a small painting in the style of {} with {}',
'a weird painting in the style of {} with {}',
'a large painting in the style of {} with {}',
]
per_img_token_list = [
'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת',
]
class PersonalizedBase(Dataset):
def __init__(self,
data_root,
size=None,
repeats=100,
interpolation="bicubic",
flip_p=0.5,
set="train",
placeholder_token="*",
per_image_tokens=False,
center_crop=False,
):
self.data_root = data_root
self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
# self._length = len(self.image_paths)
self.num_images = len(self.image_paths)
self._length = self.num_images
self.placeholder_token = placeholder_token
self.per_image_tokens = per_image_tokens
self.center_crop = center_crop
if per_image_tokens:
assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'."
if set == "train":
self._length = self.num_images * repeats
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
def __len__(self):
return self._length
def __getitem__(self, i):
example = {}
image = Image.open(self.image_paths[i % self.num_images])
if not image.mode == "RGB":
image = image.convert("RGB")
if self.per_image_tokens and np.random.uniform() < 0.25:
text = random.choice(imagenet_dual_templates_small).format(self.placeholder_token, per_img_token_list[i % self.num_images])
else:
text = random.choice(imagenet_templates_small).format(self.placeholder_token)
example["caption"] = text
# default to score-sde preprocessing
img = np.array(image).astype(np.uint8)
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = img.shape[0], img.shape[1]
img = img[(h - crop) // 2:(h + crop) // 2,
(w - crop) // 2:(w + crop) // 2]
image = Image.fromarray(img)
if self.size is not None:
image = image.resize((self.size, self.size), resample=self.interpolation)
image = self.flip(image)
image = np.array(image).astype(np.uint8)
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
return example
================================================
FILE: ldm/lr_scheduler.py
================================================
import numpy as np
class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps:
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
1 + np.cos(t * np.pi))
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n,**kwargs)
class LambdaWarmUpCosineScheduler2:
"""
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
interval = 0
for cl in self.cum_cycles[1:]:
if n <= cl:
return interval
interval += 1
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi))
self.last_f = f
return f
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
self.last_f = f
return f
================================================
FILE: ldm/models/autoencoder.py
================================================
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import Encoder, Decoder
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from ldm.util import instantiate_from_config
class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False
):
super().__init__()
self.embed_dim = embed_dim
self.n_embed = n_embed
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info
def encode_to_prequant(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def decode_code(self, code_b):
quant_b = self.quantize.embed_code(code_b)
dec = self.decode(quant_b)
return dec
def forward(self, input, return_pred_indices=False):
quant, diff, (_,_,ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
return dec, diff
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
if self.global_step <= 4:
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = x.detach()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
# https://github.com/pytorch/pytorch/issues/37142
# try not to fool the heuristics
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train",
predicted_indices=ind)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(f"val{suffix}/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log(f"val{suffix}/aeloss", aeloss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
if version.parse(pl.__version__) >= version.parse('1.4.0'):
del log_dict_ae[f"val{suffix}/rec_loss"]
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor*self.learning_rate
print("lr_d", lr_d)
print("lr_g", lr_g)
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr_g, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr_d, betas=(0.5, 0.9))
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
{
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
]
return [opt_ae, opt_disc], scheduler
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
log["inputs"] = x
return log
xrec, _ = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["inputs"] = x
log["reconstructions"] = xrec
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, h, force_not_quantize=False):
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
else:
quant = h
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
class AutoencoderKL(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val")
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val")
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
================================================
FILE: ldm/models/diffusion/__init__.py
================================================
================================================
FILE: ldm/models/diffusion/classifier.py
================================================
import os
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from copy import deepcopy
from einops import rearrange
from glob import glob
from natsort import natsorted
from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
__models__ = {
'class_label': EncoderUNetModel,
'segmentation': UNetModel
}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class NoisyLatentImageClassifier(pl.LightningModule):
def __init__(self,
diffusion_path,
num_classes,
ckpt_path=None,
pool='attention',
label_key=None,
diffusion_ckpt_path=None,
scheduler_config=None,
weight_decay=1.e-2,
log_steps=10,
monitor='val/loss',
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.num_classes = num_classes
# get latest config of diffusion model
diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
self.diffusion_config = OmegaConf.load(diffusion_config).model
self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
self.load_diffusion()
self.monitor = monitor
self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
self.log_steps = log_steps
self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
else self.diffusion_model.cond_stage_key
assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
if self.label_key not in __models__:
raise NotImplementedError()
self.load_classifier(ckpt_path, pool)
self.scheduler_config = scheduler_config
self.use_scheduler = self.scheduler_config is not None
self.weight_decay = weight_decay
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
def load_diffusion(self):
model = instantiate_from_config(self.diffusion_config)
self.diffusion_model = model.eval()
self.diffusion_model.train = disabled_train
for param in self.diffusion_model.parameters():
param.requires_grad = False
def load_classifier(self, ckpt_path, pool):
model_config = deepcopy(self.diffusion_config.params.unet_config.params)
model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
model_config.out_channels = self.num_classes
if self.label_key == 'class_label':
model_config.pool = pool
self.model = __models__[self.label_key](**model_config)
if ckpt_path is not None:
print('#####################################################################')
print(f'load from ckpt "{ckpt_path}"')
print('#####################################################################')
self.init_from_ckpt(ckpt_path)
@torch.no_grad()
def get_x_noisy(self, x, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x))
continuous_sqrt_alpha_cumprod = None
if self.diffusion_model.use_continuous_noise:
continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
# todo: make sure t+1 is correct here
return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
def forward(self, x_noisy, t, *args, **kwargs):
return self.model(x_noisy, t)
@torch.no_grad()
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = rearrange(x, 'b h w c -> b c h w')
x = x.to(memory_format=torch.contiguous_format).float()
return x
@torch.no_grad()
def get_conditioning(self, batch, k=None):
if k is None:
k = self.label_key
assert k is not None, 'Needs to provide label key'
targets = batch[k].to(self.device)
if self.label_key == 'segmentation':
targets = rearrange(targets, 'b h w c -> b c h w')
for down in range(self.numd):
h, w = targets.shape[-2:]
targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
# targets = rearrange(targets,'b c h w -> b h w c')
return targets
def compute_top_k(self, logits, labels, k, reduction="mean"):
_, top_ks = torch.topk(logits, k, dim=1)
if reduction == "mean":
return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
elif reduction == "none":
return (top_ks == labels[:, None]).float().sum(dim=-1)
def on_train_epoch_start(self):
# save some memory
self.diffusion_model.model.to('cpu')
@torch.no_grad()
def write_logs(self, loss, logits, targets):
log_prefix = 'train' if self.training else 'val'
log = {}
log[f"{log_prefix}/loss"] = loss.mean()
log[f"{log_prefix}/acc@1"] = self.compute_top_k(
logits, targets, k=1, reduction="mean"
)
log[f"{log_prefix}/acc@5"] = self.compute_top_k(
logits, targets, k=5, reduction="mean"
)
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
def shared_step(self, batch, t=None):
x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
targets = self.get_conditioning(batch)
if targets.dim() == 4:
targets = targets.argmax(dim=1)
if t is None:
t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
else:
t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
x_noisy = self.get_x_noisy(x, t)
logits = self(x_noisy, t)
loss = F.cross_entropy(logits, targets, reduction='none')
self.write_logs(loss.detach(), logits.detach(), targets.detach())
loss = loss.mean()
return loss, logits, x_noisy, targets
def training_step(self, batch, batch_idx):
loss, *_ = self.shared_step(batch)
return loss
def reset_noise_accs(self):
self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
def on_validation_start(self):
self.reset_noise_accs()
@torch.no_grad()
def validation_step(self, batch, batch_idx):
loss, *_ = self.shared_step(batch)
for t in self.noisy_acc:
_, logits, _, targets = self.shared_step(batch, t)
self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
return loss
def configure_optimizers(self):
optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
if self.use_scheduler:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
}]
return [optimizer], scheduler
return optimizer
@torch.no_grad()
def log_images(self, batch, N=8, *args, **kwargs):
log = dict()
x = self.get_input(batch, self.diffusion_model.first_stage_key)
log['inputs'] = x
y = self.get_conditioning(batch)
if self.label_key == 'class_label':
y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
log['labels'] = y
if ismap(y):
log['labels'] = self.diffusion_model.to_rgb(y)
for step in range(self.log_steps):
current_time = step * self.log_time_interval
_, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
log[f'inputs@t{current_time}'] = x_noisy
pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
pred = rearrange(pred, 'b h w c -> b c h w')
log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
for key in log:
log[key] = log[key][:N]
return log
================================================
FILE: ldm/models/diffusion/ddim.py
================================================
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
extract_into_tensor
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
return x_dec
================================================
FILE: ldm/models/diffusion/ddpm.py
================================================
"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
https://github.com/CompVis/taming-transformers
-- merci
"""
import torch
import torch.nn as nn
import os
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager
from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler
__conditioning_keys__ = {'concat': 'c_concat',
'crossattn': 'c_crossattn',
'adm': 'y'}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def uniform_on_device(r1, r2, shape, device):
return (r1 - r2) * torch.rand(*shape, device=device) + r2
class DDPM(pl.LightningModule):
# classic DDPM with Gaussian diffusion, in image space
def __init__(self,
unet_config,
timesteps=1000,
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
load_only_unet=False,
monitor="val/loss",
use_ema=True,
first_stage_key="image",
image_size=256,
channels=3,
log_every_t=100,
clip_denoised=True,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
given_betas=None,
original_elbo_weight=0.,
embedding_reg_weight=0.,
unfreeze_model=False,
model_lr=0.,
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
l_simple_weight=1.,
conditioning_key=None,
parameterization="eps", # all assuming fixed variance schedules
scheduler_config=None,
use_positional_encodings=False,
learn_logvar=False,
logvar_init=0.,
):
super().__init__()
assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
self.parameterization = parameterization
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
self.cond_stage_model = None
self.clip_denoised = clip_denoised
self.log_every_t = log_every_t
self.first_stage_key = first_stage_key
self.image_size = image_size # try conv?
self.channels = channels
self.use_positional_encodings = use_positional_encodings
self.model = DiffusionWrapper(unet_config, conditioning_key)
count_params(self.model, verbose=True)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
self.scheduler_config = scheduler_config
self.v_posterior = v_posterior
self.original_elbo_weight = original_elbo_weight
self.l_simple_weight = l_simple_weight
self.embedding_reg_weight = embedding_reg_weight
self.unfreeze_model = unfreeze_model
self.model_lr = model_lr
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
self.loss_type = loss_type
self.learn_logvar = learn_logvar
self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
if self.learn_logvar:
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if exists(given_betas):
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
1. - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', to_torch(posterior_variance))
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
self.register_buffer('posterior_mean_coef1', to_torch(
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
self.register_buffer('posterior_mean_coef2', to_torch(
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
if self.parameterization == "eps":
lvlb_weights = self.betas ** 2 / (
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
elif self.parameterization == "x0":
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
else:
raise NotImplementedError("mu not supported")
# TODO how to choose this term
lvlb_weights[0] = lvlb_weights[1]
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
assert not torch.isnan(self.lvlb_weights).all()
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
return (
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, clip_denoised: bool):
model_out = self.model(x, t)
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
if clip_denoised:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
intermediates = [img]
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
clip_denoised=self.clip_denoised)
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
intermediates.append(img)
if return_intermediates:
return img, intermediates
return img
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size),
return_intermediates=return_intermediates)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
def get_loss(self, pred, target, mean=True):
if self.loss_type == 'l1':
loss = (target - pred).abs()
if mean:
loss = loss.mean()
elif self.loss_type == 'l2':
if mean:
loss = torch.nn.functional.mse_loss(target, pred)
else:
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
else:
raise NotImplementedError("unknown loss type '{loss_type}'")
return loss
def p_losses(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_out = self.model(x_noisy, t)
loss_dict = {}
if self.parameterization == "eps":
target = noise
elif self.parameterization == "x0":
target = x_start
else:
raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
log_prefix = 'train' if self.training else 'val'
loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
loss_simple = loss.mean() * self.l_simple_weight
loss_vlb = (self.lvlb_weights[t] * loss).mean()
loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
loss = loss_simple + self.original_elbo_weight * loss_vlb
loss_dict.update({f'{log_prefix}/loss': loss})
return loss, loss_dict
def forward(self, x, *args, **kwargs):
# b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
# assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
return self.p_losses(x, t, *args, **kwargs)
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = rearrange(x, 'b h w c -> b c h w')
x = x.to(memory_format=torch.contiguous_format).float()
return x
def shared_step(self, batch):
x = self.get_input(batch, self.first_stage_key)
loss, loss_dict = self(x)
return loss, loss_dict
def training_step(self, batch, batch_idx):
loss, loss_dict = self.shared_step(batch)
self.log_dict(loss_dict, prog_bar=True,
logger=True, on_step=True, on_epoch=True)
self.log("global_step", self.global_step,
prog_bar=True, logger=True, on_step=True, on_epoch=False)
if self.use_scheduler:
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
return loss
@torch.no_grad()
def validation_step(self, batch, batch_idx):
_, loss_dict_no_ema = self.shared_step(batch)
with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch)
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self.model)
def _get_rows_from_list(self, samples):
n_imgs_per_row = len(samples)
denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
log = dict()
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
x = x.to(self.device)[:N]
log["inputs"] = x
# get diffusion row
diffusion_row = list()
x_start = x[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(x_start)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
diffusion_row.append(x_noisy)
log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
if sample:
# get denoise row
with self.ema_scope("Plotting"):
samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
log["samples"] = samples
log["denoise_row"] = self._get_rows_from_list(denoise_row)
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
else:
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
if self.learn_logvar:
params = params + [self.logvar]
opt = torch.optim.AdamW(params, lr=lr)
return opt
class LatentDiffusion(DDPM):
"""main class"""
def __init__(self,
first_stage_config,
cond_stage_config,
personalization_config,
num_timesteps_cond=None,
cond_stage_key="image",
cond_stage_trainable=False,
concat_mode=True,
cond_stage_forward=None,
conditioning_key=None,
scale_factor=1.0,
scale_by_std=False,
reg_weight = 1.0,
*args, **kwargs):
self.reg_weight = reg_weight
self.num_timesteps_cond = default(num_timesteps_cond, 1)
self.scale_by_std = scale_by_std
assert self.num_timesteps_cond <= kwargs['timesteps']
# for backwards compatibility after implementation of DiffusionWrapper
if conditioning_key is None:
conditioning_key = 'concat' if concat_mode else 'crossattn'
if cond_stage_config == '__is_unconditional__':
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
else:
self.register_buffer('scale_factor', torch.tensor(scale_factor))
self.instantiate_first_stage(first_stage_config)
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.bbox_tokenizer = None
self.restarted_from_ckpt = False
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys)
self.restarted_from_ckpt = True
if not self.unfreeze_model:
self.cond_stage_model.eval()
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
self.model.eval()
self.model.train = disabled_train
for param in self.model.parameters():
param.requires_grad = False
self.embedding_manager = None
def make_cond_schedule(self, ):
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
self.cond_ids[:self.num_timesteps_cond] = ids
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
# only for very first batch
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
# set rescale weight to 1./std of encodings
print("### USING STD-RESCALING ###")
x = super().get_input(batch, self.first_stage_key)
x = x.to(self.device)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
del self.scale_factor
self.register_buffer('scale_factor', 1. / z.flatten().std())
print(f"setting self.scale_factor to {self.scale_factor}")
print("### USING STD-RESCALING ###")
def register_schedule(self,
given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
self.shorten_cond_schedule = self.num_timesteps_cond > 1
if self.shorten_cond_schedule:
self.make_cond_schedule()
def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
param.requires_grad = False
def instantiate_cond_stage(self, config):
if not self.cond_stage_trainable:
if config == "__is_first_stage__":
print("Using first stage also as cond stage.")
self.cond_stage_model = self.first_stage_model
elif config == "__is_unconditional__":
print(f"Training {self.__class__.__name__} as an unconditional model.")
self.cond_stage_model = None
# self.be_unconditional = True
else:
model = instantiate_from_config(config)
self.cond_stage_model = model.eval()
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
else:
assert config != '__is_first_stage__'
assert config != '__is_unconditional__'
model = instantiate_from_config(config)
self.cond_stage_model = model
# def instantiate_embedding_manager(self, config, embedder):
# model = instantiate_from_config(config, embedder=embedder)
# if config.params.get("embedding_manager_ckpt", None): # do not load if missing OR empty string
# model.load(config.params.embedding_manager_ckpt)
# return model
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
denoise_row = []
for zd in tqdm(samples, desc=desc):
denoise_row.append(self.decode_first_stage(zd.to(self.device),
force_not_quantize=force_no_decoder_quantization))
n_imgs_per_row = len(denoise_row)
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
def get_first_stage_encoding(self, encoder_posterior):
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
z = encoder_posterior.sample()
elif isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
else:
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
return self.scale_factor * z
def get_learned_conditioning(self, c):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
c = self.cond_stage_model.encode(c, embedding_manager=self.embedding_manager)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
def meshgrid(self, h, w):
y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
arr = torch.cat([y, x], dim=-1)
return arr
def delta_border(self, h, w):
"""
:param h: height
:param w: width
:return: normalized distance to image border,
wtith min distance = 0 at border and max dist = 0.5 at image center
"""
lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
arr = self.meshgrid(h, w) / lower_right_corner
dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
return edge_dist
def get_weighting(self, h, w, Ly, Lx, device):
weighting = self.delta_border(h, w)
weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
self.split_input_params["clip_max_weight"], )
weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
if self.split_input_params["tie_braker"]:
L_weighting = self.delta_border(Ly, Lx)
L_weighting = torch.clip(L_weighting,
self.split_input_params["clip_min_tie_weight"],
self.split_input_params["clip_max_tie_weight"])
L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
weighting = weighting * L_weighting
return weighting
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
"""
:param x: img of size (bs, c, h, w)
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
"""
bs, nc, h, w = x.shape
# number of crops in image
Ly = (h - kernel_size[0]) // stride[0] + 1
Lx = (w - kernel_size[1]) // stride[1] + 1
if uf == 1 and df == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
elif uf > 1 and df == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
dilation=1, padding=0,
stride=(stride[0] * uf, stride[1] * uf))
fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
elif df > 1 and uf == 1:
fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
unfold = torch.nn.Unfold(**fold_params)
fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
dilation=1, padding=0,
stride=(stride[0] // df, stride[1] // df))
fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
else:
raise NotImplementedError
return fold, unfold, normalization, weighting
@torch.no_grad()
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
cond_key=None, return_original_cond=False, bs=None):
x = super().get_input(batch, k)
if bs is not None:
x = x[:bs]
x = x.to(self.device)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
if self.model.conditioning_key is not None:
if cond_key is None:
cond_key = self.cond_stage_key
if cond_key != self.first_stage_key:
if cond_key in ['caption', 'coordinates_bbox']:
xc = batch[cond_key]
elif cond_key == 'class_label':
xc = batch
else:
xc = super().get_input(batch, cond_key).to(self.device)
else:
xc = x
if not self.cond_stage_trainable or force_c_encode:
if isinstance(xc, dict) or isinstance(xc, list):
# import pudb; pudb.set_trace()
c = self.get_learned_conditioning(xc)
else:
c = self.get_learned_conditioning(xc.to(self.device))
else:
c = xc
if bs is not None:
c = c[:bs]
if self.use_positional_encodings:
pos_x, pos_y = self.compute_latent_shifts(batch)
ckey = __conditioning_keys__[self.model.conditioning_key]
c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
else:
c = None
xc = None
if self.use_positional_encodings:
pos_x, pos_y = self.compute_latent_shifts(batch)
c = {'pos_x': pos_x, 'pos_y': pos_y}
out = [z, c]
if return_first_stage_outputs:
xrec = self.decode_first_stage(z)
out.extend([x, xrec])
if return_original_cond:
out.append(xc)
return out
@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if predict_cids:
if z.dim() == 4:
z = torch.argmax(z.exp(), dim=1).long()
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
z = rearrange(z, 'b h w c -> b c h w').contiguous()
z = 1. / self.scale_factor * z
if hasattr(self, "split_input_params"):
if self.split_input_params["patch_distributed_vq"]:
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
uf = self.split_input_params["vqf"]
bs, nc, h, w = z.shape
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
print("reducing Kernel")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
print("reducing stride")
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
z = unfold(z) # (bn, nc * prod(**ks), L)
# 1. Reshape to img shape
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim
if isinstance(self.first_stage_model, VQModelInterface):
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])]
else:
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
for i in range(z.shape[-1])]
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
o = o * weighting
# Reverse 1. reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization # norm is shape (1, 1, h, w)
return decoded
else:
if isinstance(self.first_stage_model, VQModelInterface):
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
else:
return self.first_stage_model.decode(z)
else:
if isinstance(self.first_stage_model, VQModelInterface):
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
else:
return self.first_stage_model.decode(z)
# same as above but without decorator
def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if predict_cids:
if z.dim() == 4:
z = torch.argmax(z.exp(), dim=1).long()
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
z = rearrange(z, 'b h w c -> b c h w').contiguous()
z = 1. / self.scale_factor * z
if hasattr(self, "split_input_params"):
if self.split_input_params["patch_distributed_vq"]:
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
uf = self.split_input_params["vqf"]
bs, nc, h, w = z.shape
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
print("reducing Kernel")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
print("reducing stride")
fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
z = unfold(z) # (bn, nc * prod(**ks), L)
# 1. Reshape to img shape
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim
if isinstance(self.first_stage_model, VQModelInterface):
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])]
else:
output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
for i in range(z.shape[-1])]
o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
o = o * weighting
# Reverse 1. reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization # norm is shape (1, 1, h, w)
return decoded
else:
if isinstance(self.first_stage_model, VQModelInterface):
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
else:
return self.first_stage_model.decode(z)
else:
if isinstance(self.first_stage_model, VQModelInterface):
return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
else:
return self.first_stage_model.decode(z)
@torch.no_grad()
def encode_first_stage(self, x):
if hasattr(self, "split_input_params"):
if self.split_input_params["patch_distributed_vq"]:
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
df = self.split_input_params["vqf"]
self.split_input_params['original_image_size'] = x.shape[-2:]
bs, nc, h, w = x.shape
if ks[0] > h or ks[1] > w:
ks = (min(ks[0], h), min(ks[1], w))
print("reducing Kernel")
if stride[0] > h or stride[1] > w:
stride = (min(stride[0], h), min(stride[1], w))
print("reducing stride")
fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
z = unfold(x) # (bn, nc * prod(**ks), L)
# Reshape to img shape
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
for i in range(z.shape[-1])]
o = torch.stack(output_list, axis=-1)
o = o * weighting
# Reverse reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
decoded = fold(o)
decoded = decoded / normalization
return decoded
else:
return self.first_stage_model.encode(x)
else:
return self.first_stage_model.encode(x)
def shared_step(self, batch, **kwargs):
x, c = self.get_input(batch, self.first_stage_key)
loss = self(x, c)
return loss
def training_step(self, batch, batch_idx):
train_batch = batch[0]
reg_batch = batch[1]
loss_train, loss_dict = self.shared_step(train_batch)
loss_reg, _ = self.shared_step(reg_batch)
loss = loss_train + self.reg_weight * loss_reg
self.log_dict(loss_dict, prog_bar=True,
logger=True, on_step=True, on_epoch=True)
self.log("global_step", self.global_step,
prog_bar=True, logger=True, on_step=True, on_epoch=False)
if self.use_scheduler:
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
return loss
def forward(self, x, c, *args, **kwargs):
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
if self.model.conditioning_key is not None:
assert c is not None
if self.cond_stage_trainable:
c = self.get_learned_conditioning(c)
if self.shorten_cond_schedule: # TODO: drop this option
tc = self.cond_ids[t].to(self.device)
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
def rescale_bbox(bbox):
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
return x0, y0, w, h
return [rescale_bbox(b) for b in bboxes]
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
# hybrid case, cond is exptected to be a dict
pass
else:
if not isinstance(cond, list):
cond = [cond]
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
cond = {key: cond}
if hasattr(self, "split_input_params"):
assert len(cond) == 1 # todo can only deal with one conditioning atm
assert not return_ids
ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64)
h, w = x_noisy.shape[-2:]
fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
# Reshape to img shape
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
if self.cond_stage_key in ["image", "LR_image", "segmentation",
'bbox_img'] and self.model.conditioning_key: # todo check for completeness
c_key = next(iter(cond.keys())) # get key
c = next(iter(cond.values())) # get value
assert (len(c) == 1) # todo extend to list with more than one elem
c = c[0] # get element
c = unfold(c)
c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
elif self.cond_stage_key == 'coordinates_bbox':
assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
# assuming padding of unfold is always 0 and its dilation is always 1
n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
full_img_h, full_img_w = self.split_input_params['original_image_size']
# as we are operating on latents, we need the factor from the original image size to the
# spatial latent size to properly rescale the crops for regenerating the bbox annotations
num_downs = self.first_stage_model.encoder.num_resolutions - 1
rescale_latent = 2 ** (num_downs)
# get top left postions of patches as conforming for the bbbox tokenizer, therefore we
# need to rescale the tl patch coordinates to be in between (0,1)
tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
for patch_nr in range(z.shape[-1])]
# patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
patch_limits = [(x_tl, y_tl,
rescale_latent * ks[0] / full_img_w,
rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
# patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
# tokenize crop coordinates for the bounding boxes of the respective patches
patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
print(patch_limits_tknzd[0].shape)
# cut tknzd crop position from conditioning
assert isinstance(cond, dict), 'cond must be dict to be fed into model'
cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
print(cut_cond.shape)
adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
print(adapted_cond.shape)
adapted_cond = self.get_learned_conditioning(adapted_cond)
print(adapted_cond.shape)
adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
print(adapted_cond.shape)
cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
else:
cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
# apply model by loop over crops
output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
assert not isinstance(output_list[0],
tuple) # todo cant deal with multiple model outputs check this never happens
o = torch.stack(output_list, axis=-1)
o = o * weighting
# Reverse reshape to img shape
o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
# stitch crops together
x_recon = fold(o) / normalization
else:
x_recon = self.model(x_noisy, t, **cond)
if isinstance(x_recon, tuple) and not return_ids:
return x_recon[0]
else:
return x_recon
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
def _prior_bpd(self, x_start):
"""
Get the prior KL term for the variational lower-bound, measured in
bits-per-dim.
This term can't be optimized, as it only depends on the encoder.
:param x_start: the [N x C x ...] tensor of inputs.
:return: a batch of [N] KL values (in bits), one per batch element.
"""
batch_size = x_start.shape[0]
t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
return mean_flat(kl_prior) / np.log(2.0)
def p_losses(self, x_start, cond, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_output = self.apply_model(x_noisy, t, cond)
loss_dict = {}
prefix = 'train' if self.training else 'val'
if self.parameterization == "x0":
target = x_start
elif self.parameterization == "eps":
target = noise
else:
raise NotImplementedError()
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
logvar_t = self.logvar[t].to(self.device)
loss = loss_simple / torch.exp(logvar_t) + logvar_t
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
if self.learn_logvar:
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
loss_dict.update({'logvar': self.logvar.data.mean()})
loss = self.l_simple_weight * loss.mean()
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
loss += (self.original_elbo_weight * loss_vlb)
loss_dict.update({f'{prefix}/loss': loss})
# if self.embedding_reg_weight > 0:
# loss_embedding_reg = self.embedding_manager.embedding_to_coarse_loss().mean()
# loss_dict.update({f'{prefix}/loss_emb_reg': loss_embedding_reg})
# loss += (self.embedding_reg_weight * loss_embedding_reg)
# loss_dict.update({f'{prefix}/loss': loss})
return loss, loss_dict
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
return_x0=False, score_corrector=None, corrector_kwargs=None):
t_in = t
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
if score_corrector is not None:
assert self.parameterization == "eps"
model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
if return_codebook_ids:
model_out, logits = model_out
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
else:
raise NotImplementedError()
if clip_denoised:
x_recon.clamp_(-1., 1.)
if quantize_denoised:
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
if return_codebook_ids:
return model_mean, posterior_variance, posterior_log_variance, logits
elif return_x0:
return model_mean, posterior_variance, posterior_log_variance, x_recon
else:
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
return_codebook_ids=False, quantize_denoised=False, return_x0=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
b, *_, device = *x.shape, x.device
outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
return_codebook_ids=return_codebook_ids,
quantize_denoised=quantize_denoised,
return_x0=return_x0,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if return_codebook_ids:
raise DeprecationWarning("Support dropped.")
model_mean, _, model_log_variance, logits = outputs
elif return_x0:
model_mean, _, model_log_variance, x0 = outputs
else:
model_mean, _, model_log_variance = outputs
noise = noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
if return_codebook_ids:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
if return_x0:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
else:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
log_every_t=None):
if not log_every_t:
log_every_t = self.log_every_t
timesteps = self.num_timesteps
if batch_size is not None:
b = batch_size if batch_size is not None else shape[0]
shape = [batch_size] + list(shape)
else:
b = batch_size = shape[0]
if x_T is None:
img = torch.randn(shape, device=self.device)
else:
img = x_T
intermediates = []
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
if start_T is not None:
timesteps = min(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
total=timesteps) if verbose else reversed(
range(0, timesteps))
if type(temperature) == float:
temperature = [temperature] * timesteps
for i in iterator:
ts = torch.full((b,), i, device=self.device, dtype=torch.long)
if self.shorten_cond_schedule:
assert self.model.conditioning_key != 'hybrid'
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
img, x0_partial = self.p_sample(img, cond, ts,
clip_denoised=self.clip_denoised,
quantize_denoised=quantize_denoised, return_x0=True,
temperature=temperature[i], noise_dropout=noise_dropout,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
if callback: callback(i)
if img_callback: img_callback(img, i)
return img, intermediates
@torch.no_grad()
def p_sample_loop(self, cond, shape, return_intermediates=False,
x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, start_T=None,
log_every_t=None):
if not log_every_t:
log_every_t = self.log_every_t
device = self.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
intermediates = [img]
if timesteps is None:
timesteps = self.num_timesteps
if start_T is not None:
timesteps = min(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
range(0, timesteps))
if mask is not None:
assert x0 is not None
assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
for i in iterator:
ts = torch.full((b,), i, device=device, dtype=torch.long)
if self.shorten_cond_schedule:
assert self.model.conditioning_key != 'hybrid'
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
img = self.p_sample(img, cond, ts,
clip_denoised=self.clip_denoised,
quantize_denoised=quantize_denoised)
if mask is not None:
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
if callback: callback(i)
if img_callback: img_callback(img, i)
if return_intermediates:
return img, intermediates
return img
@torch.no_grad()
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
verbose=True, timesteps=None, quantize_denoised=False,
mask=None, x0=None, shape=None,**kwargs):
if shape is None:
shape = (batch_size, self.channels, self.image_size, self.image_size)
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
shape,
return_intermediates=return_intermediates, x_T=x_T,
verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
mask=mask, x0=x0)
@torch.no_grad()
def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
if ddim:
ddim_sampler = DDIMSampler(self)
shape = (self.channels, self.image_size, self.image_size)
samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
shape,cond,verbose=False,**kwargs)
else:
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
return_intermediates=True,**kwargs)
return samples, intermediates
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,
plot_diffusion_rows=False, **kwargs):
use_ddim = ddim_steps is not None
log = dict()
batch = batch[0]
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
return_original_cond=True,
bs=N)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
log["conditioning"] = xc
elif self.cond_stage_key == 'class_label':
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
log['conditioning'] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with self.ema_scope("Plotting"):
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
ddim_steps=ddim_steps,eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
uc = self.get_learned_conditioning(len(c) * [""])
sample_scaled, _ = self.sample_log(cond=c,
batch_size=N,
ddim=use_ddim,
ddim_steps=ddim_steps,
eta=ddim_eta,
unconditional_guidance_scale=5.0,
unconditional_conditioning=uc)
log["samples_scaled"] = self.decode_first_stage(sample_scaled)
if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
self.first_stage_model, IdentityFirstStage):
# also display when quantizing x0 while sampling
with self.ema_scope("Plotting Quantized Denoised"):
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
ddim_steps=ddim_steps,eta=ddim_eta,
quantize_denoised=True)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
# quantize_denoised=True)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_x0_quantized"] = x_samples
if inpaint:
# make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
mask = mask[:, None, ...]
with self.ema_scope("Plotting Inpaint"):
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_inpainting"] = x_samples
log["mask"] = mask
# outpaint
with self.ema_scope("Plotting Outpaint"):
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_outpainting"] = x_samples
if plot_progressive_rows:
with self.ema_scope("Plotting Progressives"):
img, progressives = self.progressive_denoising(c,
shape=(self.channels, self.image_size, self.image_size),
batch_size=N)
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
log["progressive_row"] = prog_row
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
else:
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
lr = self.learning_rate
if self.embedding_manager is not None: # If using textual inversion
embedding_params = list(self.embedding_manager.embedding_parameters())
if self.unfreeze_model: # Are we allowing the base model to train? If so, set two different parameter groups.
model_params = list(self.cond_stage_model.parameters()) + list(self.model.parameters())
opt = torch.optim.AdamW([{"params": embedding_params, "lr": lr}, {"params": model_params}], lr=self.model_lr)
else: # Otherwise, train only embedding
opt = torch.optim.AdamW(embedding_params, lr=lr)
else:
params = list(self.model.parameters())
if self.cond_stage_trainable:
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
params = params + list(self.cond_stage_model.parameters())
if self.learn_logvar:
print('Diffusion model optimizing logvar')
params.append(self.logvar)
opt = torch.optim.AdamW(params, lr=lr)
return opt
def configure_opt_embedding(self):
self.cond_stage_model.eval()
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
self.model.eval()
self.model.train = disabled_train
for param in self.model.parameters():
param.requires_grad = False
for param in self.embedding_manager.embedding_parameters():
param.requires_grad = True
lr = self.learning_rate
params = list(self.embedding_manager.embedding_parameters())
return torch.optim.AdamW(params, lr=lr)
def configure_opt_model(self):
for param in self.cond_stage_model.parameters():
param.requires_grad = True
for param in self.model.parameters():
param.requires_grad = True
for param in self.embedding_manager.embedding_parameters():
param.requires_grad = True
model_params = list(self.cond_stage_model.parameters()) + list(self.model.parameters())
embedding_params = list(self.embedding_manager.embedding_parameters())
return torch.optim.AdamW([{"params": embedding_params, "lr": self.learning_rate}, {"params": model_params}], lr=self.model_lr)
@torch.no_grad()
def to_rgb(self, x):
x = x.float()
if not hasattr(self, "colorize"):
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
x = nn.functional.conv2d(x, weight=self.colorize)
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
# @rank_zero_only
# def on_save_checkpoint(self, checkpoint):
# if not self.unfreeze_model: # If we are not tuning the model itself, zero-out the checkpoint content to preserve memory.
# checkpoint.clear()
# if os.path.isdir(self.trainer.checkpoint_callback.dirpath):
# self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, "embeddings.pt"))
# self.embedding_manager.save(os.path.join(self.trainer.checkpoint_callback.dirpath, f"embeddings_gs-{self.global_step}.pt"))
class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config, conditioning_key):
super().__init__()
self.diffusion_model = instantiate_from_config(diff_model_config)
self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
if self.conditioning_key is None:
out = self.diffusion_model(x, t)
elif self.conditioning_key == 'concat':
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t)
elif self.conditioning_key == 'crossattn':
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc)
elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc)
elif self.conditioning_key == 'adm':
cc = c_crossattn[0]
out = self.diffusion_model(x, t, y=cc)
else:
raise NotImplementedError()
return out
class Layout2ImgDiffusion(LatentDiffusion):
# TODO: move all layout-specific hacks to this class
def __init__(self, cond_stage_key, *args, **kwargs):
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
def log_images(self, batch, N=8, *args, **kwargs):
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
key = 'train' if self.training else 'validation'
dset = self.trainer.datamodule.datasets[key]
mapper = dset.conditional_builders[self.cond_stage_key]
bbox_imgs = []
map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
for tknzd_bbox in batch[self.cond_stage_key][:N]:
bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
bbox_imgs.append(bboximg)
cond_img = torch.stack(bbox_imgs, dim=0)
logs['bbox_image'] = cond_img
return logs
================================================
FILE: ldm/models/diffusion/plms.py
================================================
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.modules.diffus
gitextract_pv77kbwr/ ├── LICENSE ├── README.md ├── configs/ │ ├── autoencoder/ │ │ ├── autoencoder_kl_16x16x16.yaml │ │ ├── autoencoder_kl_32x32x4.yaml │ │ ├── autoencoder_kl_64x64x3.yaml │ │ └── autoencoder_kl_8x8x64.yaml │ ├── latent-diffusion/ │ │ ├── celebahq-ldm-vq-4.yaml │ │ ├── cin-ldm-vq-f8.yaml │ │ ├── cin256-v2.yaml │ │ ├── ffhq-ldm-vq-4.yaml │ │ ├── lsun_bedrooms-ldm-vq-4.yaml │ │ ├── lsun_churches-ldm-kl-8.yaml │ │ ├── txt2img-1p4B-eval.yaml │ │ ├── txt2img-1p4B-eval_with_tokens.yaml │ │ ├── txt2img-1p4B-finetune.yaml │ │ └── txt2img-1p4B-finetune_style.yaml │ └── stable-diffusion/ │ ├── v1-finetune.yaml │ ├── v1-finetune_unfrozen.yaml │ └── v1-inference.yaml ├── environment.yaml ├── evaluation/ │ └── clip_eval.py ├── ldm/ │ ├── data/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── imagenet.py │ │ ├── lsun.py │ │ ├── personalized.py │ │ └── personalized_style.py │ ├── lr_scheduler.py │ ├── models/ │ │ ├── autoencoder.py │ │ └── diffusion/ │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ └── plms.py │ ├── modules/ │ │ ├── attention.py │ │ ├── diffusionmodules/ │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ └── util.py │ │ ├── distributions/ │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── embedding_manager.py │ │ ├── encoders/ │ │ │ ├── __init__.py │ │ │ ├── modules.py │ │ │ └── modules_bak.py │ │ ├── image_degradation/ │ │ │ ├── __init__.py │ │ │ ├── bsrgan.py │ │ │ ├── bsrgan_light.py │ │ │ └── utils_image.py │ │ ├── losses/ │ │ │ ├── __init__.py │ │ │ ├── contperceptual.py │ │ │ └── vqperceptual.py │ │ └── x_transformer.py │ └── util.py ├── main.py ├── merge_embeddings.py ├── models/ │ ├── first_stage_models/ │ │ ├── kl-f16/ │ │ │ └── config.yaml │ │ ├── kl-f32/ │ │ │ └── config.yaml │ │ ├── kl-f4/ │ │ │ └── config.yaml │ │ ├── kl-f8/ │ │ │ └── config.yaml │ │ ├── vq-f16/ │ │ │ └── config.yaml │ │ ├── vq-f4/ │ │ │ └── config.yaml │ │ ├── vq-f4-noattn/ │ │ │ └── config.yaml │ │ ├── vq-f8/ │ │ │ └── config.yaml │ │ └── vq-f8-n256/ │ │ └── config.yaml │ └── ldm/ │ ├── bsr_sr/ │ │ └── config.yaml │ ├── celeba256/ │ │ └── config.yaml │ ├── cin256/ │ │ └── config.yaml │ ├── ffhq256/ │ │ └── config.yaml │ ├── inpainting_big/ │ │ └── config.yaml │ ├── layout2img-openimages256/ │ │ └── config.yaml │ ├── lsun_beds256/ │ │ └── config.yaml │ ├── lsun_churches256/ │ │ └── config.yaml │ ├── semantic_synthesis256/ │ │ └── config.yaml │ ├── semantic_synthesis512/ │ │ └── config.yaml │ └── text2img256/ │ └── config.yaml ├── scripts/ │ ├── download_first_stages.sh │ ├── download_models.sh │ ├── evaluate_model.py │ ├── inpaint.py │ ├── latent_imagenet_diffusion.ipynb │ ├── sample_diffusion.py │ ├── stable_txt2img.py │ └── txt2img.py └── setup.py
SYMBOL INDEX (724 symbols across 35 files)
FILE: evaluation/clip_eval.py
class CLIPEvaluator (line 7) | class CLIPEvaluator(object):
method __init__ (line 8) | def __init__(self, device, clip_model='ViT-B/32') -> None:
method tokenize (line 18) | def tokenize(self, strings: list):
method encode_text (line 22) | def encode_text(self, tokens: list) -> torch.Tensor:
method encode_images (line 26) | def encode_images(self, images: torch.Tensor) -> torch.Tensor:
method get_text_features (line 30) | def get_text_features(self, text: str, norm: bool = True) -> torch.Ten...
method get_image_features (line 41) | def get_image_features(self, img: torch.Tensor, norm: bool = True) -> ...
method img_to_img_similarity (line 49) | def img_to_img_similarity(self, src_images, generated_images):
method txt_to_img_similarity (line 55) | def txt_to_img_similarity(self, text, generated_images):
class LDMCLIPEvaluator (line 62) | class LDMCLIPEvaluator(CLIPEvaluator):
method __init__ (line 63) | def __init__(self, device, clip_model='ViT-B/32') -> None:
method evaluate (line 66) | def evaluate(self, ldm_model, src_images, target_text, n_samples=64, n...
class ImageDirEvaluator (line 104) | class ImageDirEvaluator(CLIPEvaluator):
method __init__ (line 105) | def __init__(self, device, clip_model='ViT-B/32') -> None:
method evaluate (line 108) | def evaluate(self, gen_samples, src_images, target_text):
FILE: ldm/data/base.py
class Txt2ImgIterableBaseDataset (line 5) | class Txt2ImgIterableBaseDataset(IterableDataset):
method __init__ (line 9) | def __init__(self, num_records=0, valid_ids=None, size=256):
method __len__ (line 18) | def __len__(self):
method __iter__ (line 22) | def __iter__(self):
FILE: ldm/data/imagenet.py
function synset2idx (line 20) | def synset2idx(path_to_yaml="data/index_synset.yaml"):
class ImageNetBase (line 26) | class ImageNetBase(Dataset):
method __init__ (line 27) | def __init__(self, config=None):
method __len__ (line 39) | def __len__(self):
method __getitem__ (line 42) | def __getitem__(self, i):
method _prepare (line 45) | def _prepare(self):
method _filter_relpaths (line 48) | def _filter_relpaths(self, relpaths):
method _prepare_synset_to_human (line 66) | def _prepare_synset_to_human(self):
method _prepare_idx_to_synset (line 74) | def _prepare_idx_to_synset(self):
method _prepare_human_to_integer_label (line 80) | def _prepare_human_to_integer_label(self):
method _load (line 93) | def _load(self):
class ImageNetTrain (line 134) | class ImageNetTrain(ImageNetBase):
method __init__ (line 145) | def __init__(self, process_images=True, data_root=None, **kwargs):
method _prepare (line 150) | def _prepare(self):
class ImageNetValidation (line 197) | class ImageNetValidation(ImageNetBase):
method __init__ (line 211) | def __init__(self, process_images=True, data_root=None, **kwargs):
method _prepare (line 216) | def _prepare(self):
class ImageNetSR (line 272) | class ImageNetSR(Dataset):
method __init__ (line 273) | def __init__(self, size=None,
method __len__ (line 336) | def __len__(self):
method __getitem__ (line 339) | def __getitem__(self, i):
class ImageNetSRTrain (line 375) | class ImageNetSRTrain(ImageNetSR):
method __init__ (line 376) | def __init__(self, **kwargs):
method get_base (line 379) | def get_base(self):
class ImageNetSRValidation (line 386) | class ImageNetSRValidation(ImageNetSR):
method __init__ (line 387) | def __init__(self, **kwargs):
method get_base (line 390) | def get_base(self):
FILE: ldm/data/lsun.py
class LSUNBase (line 9) | class LSUNBase(Dataset):
method __init__ (line 10) | def __init__(self,
method __len__ (line 36) | def __len__(self):
method __getitem__ (line 39) | def __getitem__(self, i):
class LSUNChurchesTrain (line 62) | class LSUNChurchesTrain(LSUNBase):
method __init__ (line 63) | def __init__(self, **kwargs):
class LSUNChurchesValidation (line 67) | class LSUNChurchesValidation(LSUNBase):
method __init__ (line 68) | def __init__(self, flip_p=0., **kwargs):
class LSUNBedroomsTrain (line 73) | class LSUNBedroomsTrain(LSUNBase):
method __init__ (line 74) | def __init__(self, **kwargs):
class LSUNBedroomsValidation (line 78) | class LSUNBedroomsValidation(LSUNBase):
method __init__ (line 79) | def __init__(self, flip_p=0.0, **kwargs):
class LSUNCatsTrain (line 84) | class LSUNCatsTrain(LSUNBase):
method __init__ (line 85) | def __init__(self, **kwargs):
class LSUNCatsValidation (line 89) | class LSUNCatsValidation(LSUNBase):
method __init__ (line 90) | def __init__(self, flip_p=0., **kwargs):
FILE: ldm/data/personalized.py
class PersonalizedBase (line 136) | class PersonalizedBase(Dataset):
method __init__ (line 137) | def __init__(self,
method __len__ (line 183) | def __len__(self):
method __getitem__ (line 186) | def __getitem__(self, i):
FILE: ldm/data/personalized_style.py
class PersonalizedBase (line 56) | class PersonalizedBase(Dataset):
method __init__ (line 57) | def __init__(self,
method __len__ (line 96) | def __len__(self):
method __getitem__ (line 99) | def __getitem__(self, i):
FILE: ldm/lr_scheduler.py
class LambdaWarmUpCosineScheduler (line 4) | class LambdaWarmUpCosineScheduler:
method __init__ (line 8) | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_...
method schedule (line 17) | def schedule(self, n, **kwargs):
method __call__ (line 32) | def __call__(self, n, **kwargs):
class LambdaWarmUpCosineScheduler2 (line 36) | class LambdaWarmUpCosineScheduler2:
method __init__ (line 41) | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths...
method find_in_interval (line 52) | def find_in_interval(self, n):
method schedule (line 59) | def schedule(self, n, **kwargs):
method __call__ (line 77) | def __call__(self, n, **kwargs):
class LambdaLinearScheduler (line 81) | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
method schedule (line 83) | def schedule(self, n, **kwargs):
FILE: ldm/models/autoencoder.py
class VQModel (line 14) | class VQModel(pl.LightningModule):
method __init__ (line 15) | def __init__(self,
method ema_scope (line 64) | def ema_scope(self, context=None):
method init_from_ckpt (line 78) | def init_from_ckpt(self, path, ignore_keys=list()):
method on_train_batch_end (line 92) | def on_train_batch_end(self, *args, **kwargs):
method encode (line 96) | def encode(self, x):
method encode_to_prequant (line 102) | def encode_to_prequant(self, x):
method decode (line 107) | def decode(self, quant):
method decode_code (line 112) | def decode_code(self, code_b):
method forward (line 117) | def forward(self, input, return_pred_indices=False):
method get_input (line 124) | def get_input(self, batch, k):
method training_step (line 142) | def training_step(self, batch, batch_idx, optimizer_idx):
method validation_step (line 164) | def validation_step(self, batch, batch_idx):
method _validation_step (line 170) | def _validation_step(self, batch, batch_idx, suffix=""):
method configure_optimizers (line 197) | def configure_optimizers(self):
method get_last_layer (line 230) | def get_last_layer(self):
method log_images (line 233) | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
method to_rgb (line 255) | def to_rgb(self, x):
class VQModelInterface (line 264) | class VQModelInterface(VQModel):
method __init__ (line 265) | def __init__(self, embed_dim, *args, **kwargs):
method encode (line 269) | def encode(self, x):
method decode (line 274) | def decode(self, h, force_not_quantize=False):
class AutoencoderKL (line 285) | class AutoencoderKL(pl.LightningModule):
method __init__ (line 286) | def __init__(self,
method init_from_ckpt (line 313) | def init_from_ckpt(self, path, ignore_keys=list()):
method encode (line 324) | def encode(self, x):
method decode (line 330) | def decode(self, z):
method forward (line 335) | def forward(self, input, sample_posterior=True):
method get_input (line 344) | def get_input(self, batch, k):
method training_step (line 351) | def training_step(self, batch, batch_idx, optimizer_idx):
method validation_step (line 372) | def validation_step(self, batch, batch_idx):
method configure_optimizers (line 386) | def configure_optimizers(self):
method get_last_layer (line 397) | def get_last_layer(self):
method log_images (line 401) | def log_images(self, batch, only_inputs=False, **kwargs):
method to_rgb (line 417) | def to_rgb(self, x):
class IdentityFirstStage (line 426) | class IdentityFirstStage(torch.nn.Module):
method __init__ (line 427) | def __init__(self, *args, vq_interface=False, **kwargs):
method encode (line 431) | def encode(self, x, *args, **kwargs):
method decode (line 434) | def decode(self, x, *args, **kwargs):
method quantize (line 437) | def quantize(self, x, *args, **kwargs):
method forward (line 442) | def forward(self, x, *args, **kwargs):
FILE: ldm/models/diffusion/classifier.py
function disabled_train (line 22) | def disabled_train(self, mode=True):
class NoisyLatentImageClassifier (line 28) | class NoisyLatentImageClassifier(pl.LightningModule):
method __init__ (line 30) | def __init__(self,
method init_from_ckpt (line 70) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method load_diffusion (line 88) | def load_diffusion(self):
method load_classifier (line 95) | def load_classifier(self, ckpt_path, pool):
method get_x_noisy (line 110) | def get_x_noisy(self, x, t, noise=None):
method forward (line 120) | def forward(self, x_noisy, t, *args, **kwargs):
method get_input (line 124) | def get_input(self, batch, k):
method get_conditioning (line 133) | def get_conditioning(self, batch, k=None):
method compute_top_k (line 150) | def compute_top_k(self, logits, labels, k, reduction="mean"):
method on_train_epoch_start (line 157) | def on_train_epoch_start(self):
method write_logs (line 162) | def write_logs(self, loss, logits, targets):
method shared_step (line 179) | def shared_step(self, batch, t=None):
method training_step (line 198) | def training_step(self, batch, batch_idx):
method reset_noise_accs (line 202) | def reset_noise_accs(self):
method on_validation_start (line 206) | def on_validation_start(self):
method validation_step (line 210) | def validation_step(self, batch, batch_idx):
method configure_optimizers (line 220) | def configure_optimizers(self):
method log_images (line 238) | def log_images(self, batch, N=8, *args, **kwargs):
FILE: ldm/models/diffusion/ddim.py
class DDIMSampler (line 12) | class DDIMSampler(object):
method __init__ (line 13) | def __init__(self, model, schedule="linear", **kwargs):
method register_buffer (line 19) | def register_buffer(self, name, attr):
method make_schedule (line 25) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
method sample (line 57) | def sample(self,
method ddim_sampling (line 114) | def ddim_sampling(self, cond, shape,
method p_sample_ddim (line 166) | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_origin...
method stochastic_encode (line 207) | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
method decode (line 223) | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale...
FILE: ldm/models/diffusion/ddpm.py
function disabled_train (line 36) | def disabled_train(self, mode=True):
function uniform_on_device (line 42) | def uniform_on_device(r1, r2, shape, device):
class DDPM (line 46) | class DDPM(pl.LightningModule):
method __init__ (line 48) | def __init__(self,
method register_schedule (line 126) | def register_schedule(self, given_betas=None, beta_schedule="linear", ...
method ema_scope (line 181) | def ema_scope(self, context=None):
method init_from_ckpt (line 195) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method q_mean_variance (line 213) | def q_mean_variance(self, x_start, t):
method predict_start_from_noise (line 225) | def predict_start_from_noise(self, x_t, t, noise):
method q_posterior (line 231) | def q_posterior(self, x_start, x_t, t):
method p_mean_variance (line 240) | def p_mean_variance(self, x, t, clip_denoised: bool):
method p_sample (line 253) | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
method p_sample_loop (line 262) | def p_sample_loop(self, shape, return_intermediates=False):
method sample (line 277) | def sample(self, batch_size=16, return_intermediates=False):
method q_sample (line 283) | def q_sample(self, x_start, t, noise=None):
method get_loss (line 288) | def get_loss(self, pred, target, mean=True):
method p_losses (line 303) | def p_losses(self, x_start, t, noise=None):
method forward (line 332) | def forward(self, x, *args, **kwargs):
method get_input (line 338) | def get_input(self, batch, k):
method shared_step (line 346) | def shared_step(self, batch):
method training_step (line 351) | def training_step(self, batch, batch_idx):
method validation_step (line 367) | def validation_step(self, batch, batch_idx):
method on_train_batch_end (line 375) | def on_train_batch_end(self, *args, **kwargs):
method _get_rows_from_list (line 379) | def _get_rows_from_list(self, samples):
method log_images (line 387) | def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=Non...
method configure_optimizers (line 424) | def configure_optimizers(self):
class LatentDiffusion (line 433) | class LatentDiffusion(DDPM):
method __init__ (line 435) | def __init__(self,
method make_cond_schedule (line 502) | def make_cond_schedule(self, ):
method on_train_batch_start (line 509) | def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
method register_schedule (line 525) | def register_schedule(self,
method instantiate_first_stage (line 534) | def instantiate_first_stage(self, config):
method instantiate_cond_stage (line 541) | def instantiate_cond_stage(self, config):
method _get_denoise_row_from_list (line 571) | def _get_denoise_row_from_list(self, samples, desc='', force_no_decode...
method get_first_stage_encoding (line 583) | def get_first_stage_encoding(self, encoder_posterior):
method get_learned_conditioning (line 592) | def get_learned_conditioning(self, c):
method meshgrid (line 605) | def meshgrid(self, h, w):
method delta_border (line 612) | def delta_border(self, h, w):
method get_weighting (line 626) | def get_weighting(self, h, w, Ly, Lx, device):
method get_fold_unfold (line 642) | def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo...
method get_input (line 695) | def get_input(self, batch, k, return_first_stage_outputs=False, force_...
method decode_first_stage (line 747) | def decode_first_stage(self, z, predict_cids=False, force_not_quantize...
method differentiable_decode_first_stage (line 807) | def differentiable_decode_first_stage(self, z, predict_cids=False, for...
method encode_first_stage (line 867) | def encode_first_stage(self, x):
method shared_step (line 906) | def shared_step(self, batch, **kwargs):
method training_step (line 911) | def training_step(self, batch, batch_idx):
method forward (line 932) | def forward(self, x, c, *args, **kwargs):
method _rescale_annotations (line 944) | def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: mov...
method apply_model (line 954) | def apply_model(self, x_noisy, t, cond, return_ids=False):
method _predict_eps_from_xstart (line 1057) | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
method _prior_bpd (line 1061) | def _prior_bpd(self, x_start):
method p_losses (line 1075) | def p_losses(self, x_start, cond, t, noise=None):
method p_mean_variance (line 1118) | def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codeboo...
method p_sample (line 1150) | def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
method progressive_denoising (line 1181) | def progressive_denoising(self, cond, shape, verbose=True, callback=No...
method p_sample_loop (line 1237) | def p_sample_loop(self, cond, shape, return_intermediates=False,
method sample (line 1288) | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=...
method sample_log (line 1306) | def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
method log_images (line 1321) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
method configure_optimizers (line 1442) | def configure_optimizers(self):
method configure_opt_embedding (line 1466) | def configure_opt_embedding(self):
method configure_opt_model (line 1485) | def configure_opt_model(self):
method to_rgb (line 1501) | def to_rgb(self, x):
class DiffusionWrapper (line 1521) | class DiffusionWrapper(pl.LightningModule):
method __init__ (line 1522) | def __init__(self, diff_model_config, conditioning_key):
method forward (line 1528) | def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
class Layout2ImgDiffusion (line 1550) | class Layout2ImgDiffusion(LatentDiffusion):
method __init__ (line 1552) | def __init__(self, cond_stage_key, *args, **kwargs):
method log_images (line 1556) | def log_images(self, batch, N=8, *args, **kwargs):
FILE: ldm/models/diffusion/plms.py
class PLMSSampler (line 11) | class PLMSSampler(object):
method __init__ (line 12) | def __init__(self, model, schedule="linear", **kwargs):
method register_buffer (line 18) | def register_buffer(self, name, attr):
method make_schedule (line 24) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
method sample (line 58) | def sample(self,
method plms_sampling (line 115) | def plms_sampling(self, cond, shape,
method p_sample_plms (line 173) | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_origin...
FILE: ldm/modules/attention.py
function exists (line 11) | def exists(val):
function uniq (line 15) | def uniq(arr):
function default (line 19) | def default(val, d):
function max_neg_value (line 25) | def max_neg_value(t):
function init_ (line 29) | def init_(tensor):
class GEGLU (line 37) | class GEGLU(nn.Module):
method __init__ (line 38) | def __init__(self, dim_in, dim_out):
method forward (line 42) | def forward(self, x):
class FeedForward (line 47) | class FeedForward(nn.Module):
method __init__ (line 48) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 63) | def forward(self, x):
function zero_module (line 67) | def zero_module(module):
function Normalize (line 76) | def Normalize(in_channels):
class LinearAttention (line 80) | class LinearAttention(nn.Module):
method __init__ (line 81) | def __init__(self, dim, heads=4, dim_head=32):
method forward (line 88) | def forward(self, x):
class SpatialSelfAttention (line 99) | class SpatialSelfAttention(nn.Module):
method __init__ (line 100) | def __init__(self, in_channels):
method forward (line 126) | def forward(self, x):
class CrossAttention (line 152) | class CrossAttention(nn.Module):
method __init__ (line 153) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
method forward (line 170) | def forward(self, x, context=None, mask=None):
class BasicTransformerBlock (line 196) | class BasicTransformerBlock(nn.Module):
method __init__ (line 197) | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None,...
method forward (line 208) | def forward(self, x, context=None):
method _forward (line 211) | def _forward(self, x, context=None):
class SpatialTransformer (line 218) | class SpatialTransformer(nn.Module):
method __init__ (line 226) | def __init__(self, in_channels, n_heads, d_head,
method forward (line 250) | def forward(self, x, context=None):
FILE: ldm/modules/diffusionmodules/model.py
function get_timestep_embedding (line 12) | def get_timestep_embedding(timesteps, embedding_dim):
function nonlinearity (line 33) | def nonlinearity(x):
function Normalize (line 38) | def Normalize(in_channels, num_groups=32):
class Upsample (line 42) | class Upsample(nn.Module):
method __init__ (line 43) | def __init__(self, in_channels, with_conv):
method forward (line 53) | def forward(self, x):
class Downsample (line 60) | class Downsample(nn.Module):
method __init__ (line 61) | def __init__(self, in_channels, with_conv):
method forward (line 72) | def forward(self, x):
class ResnetBlock (line 82) | class ResnetBlock(nn.Module):
method __init__ (line 83) | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=Fa...
method forward (line 121) | def forward(self, x, temb):
class LinAttnBlock (line 144) | class LinAttnBlock(LinearAttention):
method __init__ (line 146) | def __init__(self, in_channels):
class AttnBlock (line 150) | class AttnBlock(nn.Module):
method __init__ (line 151) | def __init__(self, in_channels):
method forward (line 178) | def forward(self, x):
function make_attn (line 205) | def make_attn(in_channels, attn_type="vanilla"):
class Model (line 216) | class Model(nn.Module):
method __init__ (line 217) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 316) | def forward(self, x, t=None, context=None):
method get_last_layer (line 364) | def get_last_layer(self):
class Encoder (line 368) | class Encoder(nn.Module):
method __init__ (line 369) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 434) | def forward(self, x):
class Decoder (line 462) | class Decoder(nn.Module):
method __init__ (line 463) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 535) | def forward(self, z):
class SimpleDecoder (line 571) | class SimpleDecoder(nn.Module):
method __init__ (line 572) | def __init__(self, in_channels, out_channels, *args, **kwargs):
method forward (line 594) | def forward(self, x):
class UpsampleDecoder (line 607) | class UpsampleDecoder(nn.Module):
method __init__ (line 608) | def __init__(self, in_channels, out_channels, ch, num_res_blocks, reso...
method forward (line 641) | def forward(self, x):
class LatentRescaler (line 655) | class LatentRescaler(nn.Module):
method __init__ (line 656) | def __init__(self, factor, in_channels, mid_channels, out_channels, de...
method forward (line 680) | def forward(self, x):
class MergedRescaleEncoder (line 692) | class MergedRescaleEncoder(nn.Module):
method __init__ (line 693) | def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
method forward (line 705) | def forward(self, x):
class MergedRescaleDecoder (line 711) | class MergedRescaleDecoder(nn.Module):
method __init__ (line 712) | def __init__(self, z_channels, out_ch, resolution, num_res_blocks, att...
method forward (line 722) | def forward(self, x):
class Upsampler (line 728) | class Upsampler(nn.Module):
method __init__ (line 729) | def __init__(self, in_size, out_size, in_channels, out_channels, ch_mu...
method forward (line 741) | def forward(self, x):
class Resize (line 747) | class Resize(nn.Module):
method __init__ (line 748) | def __init__(self, in_channels=None, learned=False, mode="bilinear"):
method forward (line 763) | def forward(self, x, scale_factor=1.0):
class FirstStagePostProcessor (line 770) | class FirstStagePostProcessor(nn.Module):
method __init__ (line 772) | def __init__(self, ch_mult:list, in_channels,
method instantiate_pretrained (line 807) | def instantiate_pretrained(self, config):
method encode_with_pretrained (line 816) | def encode_with_pretrained(self,x):
method forward (line 822) | def forward(self,x):
FILE: ldm/modules/diffusionmodules/openaimodel.py
function convert_module_to_f16 (line 24) | def convert_module_to_f16(x):
function convert_module_to_f32 (line 27) | def convert_module_to_f32(x):
class AttentionPool2d (line 32) | class AttentionPool2d(nn.Module):
method __init__ (line 37) | def __init__(
method forward (line 51) | def forward(self, x):
class TimestepBlock (line 62) | class TimestepBlock(nn.Module):
method forward (line 68) | def forward(self, x, emb):
class TimestepEmbedSequential (line 74) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
method forward (line 80) | def forward(self, x, emb, context=None):
class Upsample (line 91) | class Upsample(nn.Module):
method __init__ (line 100) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
method forward (line 109) | def forward(self, x):
class TransposedUpsample (line 121) | class TransposedUpsample(nn.Module):
method __init__ (line 123) | def __init__(self, channels, out_channels=None, ks=5):
method forward (line 130) | def forward(self,x):
class Downsample (line 134) | class Downsample(nn.Module):
method __init__ (line 143) | def __init__(self, channels, use_conv, dims=2, out_channels=None,paddi...
method forward (line 158) | def forward(self, x):
class ResBlock (line 163) | class ResBlock(TimestepBlock):
method __init__ (line 179) | def __init__(
method forward (line 243) | def forward(self, x, emb):
method _forward (line 255) | def _forward(self, x, emb):
class AttentionBlock (line 278) | class AttentionBlock(nn.Module):
method __init__ (line 285) | def __init__(
method forward (line 314) | def forward(self, x):
method _forward (line 318) | def _forward(self, x):
function count_flops_attn (line 327) | def count_flops_attn(model, _x, y):
class QKVAttentionLegacy (line 347) | class QKVAttentionLegacy(nn.Module):
method __init__ (line 352) | def __init__(self, n_heads):
method forward (line 356) | def forward(self, qkv):
method count_flops (line 375) | def count_flops(model, _x, y):
class QKVAttention (line 379) | class QKVAttention(nn.Module):
method __init__ (line 384) | def __init__(self, n_heads):
method forward (line 388) | def forward(self, qkv):
method count_flops (line 409) | def count_flops(model, _x, y):
class UNetModel (line 413) | class UNetModel(nn.Module):
method __init__ (line 443) | def __init__(
method convert_to_fp16 (line 694) | def convert_to_fp16(self):
method convert_to_fp32 (line 702) | def convert_to_fp32(self):
method forward (line 710) | def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
class EncoderUNetModel (line 745) | class EncoderUNetModel(nn.Module):
method __init__ (line 751) | def __init__(
method convert_to_fp16 (line 924) | def convert_to_fp16(self):
method convert_to_fp32 (line 931) | def convert_to_fp32(self):
method forward (line 938) | def forward(self, x, timesteps):
FILE: ldm/modules/diffusionmodules/util.py
function make_beta_schedule (line 21) | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_e...
function make_ddim_timesteps (line 46) | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_...
function make_ddim_sampling_parameters (line 63) | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbos...
function betas_for_alpha_bar (line 77) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
function extract_into_tensor (line 96) | def extract_into_tensor(a, t, x_shape):
function checkpoint (line 102) | def checkpoint(func, inputs, params, flag):
class CheckpointFunction (line 119) | class CheckpointFunction(torch.autograd.Function):
method forward (line 121) | def forward(ctx, run_function, length, *args):
method backward (line 131) | def backward(ctx, *output_grads):
function timestep_embedding (line 151) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
function zero_module (line 174) | def zero_module(module):
function scale_module (line 183) | def scale_module(module, scale):
function mean_flat (line 192) | def mean_flat(tensor):
function normalization (line 199) | def normalization(channels):
class SiLU (line 209) | class SiLU(nn.Module):
method forward (line 210) | def forward(self, x):
class GroupNorm32 (line 214) | class GroupNorm32(nn.GroupNorm):
method forward (line 215) | def forward(self, x):
function conv_nd (line 218) | def conv_nd(dims, *args, **kwargs):
function linear (line 231) | def linear(*args, **kwargs):
function avg_pool_nd (line 238) | def avg_pool_nd(dims, *args, **kwargs):
class HybridConditioner (line 251) | class HybridConditioner(nn.Module):
method __init__ (line 253) | def __init__(self, c_concat_config, c_crossattn_config):
method forward (line 258) | def forward(self, c_concat, c_crossattn):
function noise_like (line 264) | def noise_like(shape, device, repeat=False):
FILE: ldm/modules/distributions/distributions.py
class AbstractDistribution (line 5) | class AbstractDistribution:
method sample (line 6) | def sample(self):
method mode (line 9) | def mode(self):
class DiracDistribution (line 13) | class DiracDistribution(AbstractDistribution):
method __init__ (line 14) | def __init__(self, value):
method sample (line 17) | def sample(self):
method mode (line 20) | def mode(self):
class DiagonalGaussianDistribution (line 24) | class DiagonalGaussianDistribution(object):
method __init__ (line 25) | def __init__(self, parameters, deterministic=False):
method sample (line 35) | def sample(self):
method kl (line 39) | def kl(self, other=None):
method nll (line 53) | def nll(self, sample, dims=[1,2,3]):
method mode (line 61) | def mode(self):
function normal_kl (line 65) | def normal_kl(mean1, logvar1, mean2, logvar2):
FILE: ldm/modules/ema.py
class LitEma (line 5) | class LitEma(nn.Module):
method __init__ (line 6) | def __init__(self, model, decay=0.9999, use_num_upates=True):
method forward (line 25) | def forward(self,model):
method copy_to (line 46) | def copy_to(self, model):
method store (line 55) | def store(self, parameters):
method restore (line 64) | def restore(self, parameters):
FILE: ldm/modules/embedding_manager.py
function get_clip_token_for_string (line 12) | def get_clip_token_for_string(tokenizer, string):
function get_bert_token_for_string (line 20) | def get_bert_token_for_string(tokenizer, string):
function get_embedding_for_clip_token (line 28) | def get_embedding_for_clip_token(embedder, token):
class EmbeddingManager (line 32) | class EmbeddingManager(nn.Module):
method __init__ (line 33) | def __init__(
method forward (line 88) | def forward(
method save (line 131) | def save(self, ckpt_path):
method load (line 135) | def load(self, ckpt_path):
method get_embedding_norms_squared (line 141) | def get_embedding_norms_squared(self):
method embedding_parameters (line 147) | def embedding_parameters(self):
method embedding_to_coarse_loss (line 150) | def embedding_to_coarse_loss(self):
FILE: ldm/modules/encoders/modules.py
function _expand_mask (line 11) | def _expand_mask(mask, dtype, tgt_len = None):
function _build_causal_attention_mask (line 24) | def _build_causal_attention_mask(bsz, seq_len, dtype):
class AbstractEncoder (line 33) | class AbstractEncoder(nn.Module):
method __init__ (line 34) | def __init__(self):
method encode (line 37) | def encode(self, *args, **kwargs):
class ClassEmbedder (line 42) | class ClassEmbedder(nn.Module):
method __init__ (line 43) | def __init__(self, embed_dim, n_classes=1000, key='class'):
method forward (line 48) | def forward(self, batch, key=None):
class TransformerEmbedder (line 57) | class TransformerEmbedder(AbstractEncoder):
method __init__ (line 59) | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, devic...
method forward (line 65) | def forward(self, tokens):
method encode (line 70) | def encode(self, x):
class BERTTokenizer (line 74) | class BERTTokenizer(AbstractEncoder):
method __init__ (line 76) | def __init__(self, device="cuda", vq_interface=True, max_length=77):
method forward (line 84) | def forward(self, text):
method encode (line 91) | def encode(self, text):
method decode (line 97) | def decode(self, text):
class BERTEmbedder (line 101) | class BERTEmbedder(AbstractEncoder):
method __init__ (line 103) | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
method forward (line 114) | def forward(self, text, embedding_manager=None):
method encode (line 122) | def encode(self, text, **kwargs):
class SpatialRescaler (line 126) | class SpatialRescaler(nn.Module):
method __init__ (line 127) | def __init__(self,
method forward (line 145) | def forward(self,x):
method encode (line 154) | def encode(self, x):
class FrozenCLIPEmbedder (line 157) | class FrozenCLIPEmbedder(AbstractEncoder):
method __init__ (line 159) | def __init__(self, version="openai/clip-vit-large-patch14", device="cu...
method freeze (line 310) | def freeze(self):
method forward (line 315) | def forward(self, text, **kwargs):
method encode (line 323) | def encode(self, text, **kwargs):
class FrozenCLIPTextEmbedder (line 327) | class FrozenCLIPTextEmbedder(nn.Module):
method __init__ (line 331) | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n...
method freeze (line 339) | def freeze(self):
method forward (line 344) | def forward(self, text):
method encode (line 351) | def encode(self, text):
class FrozenClipImageEmbedder (line 359) | class FrozenClipImageEmbedder(nn.Module):
method __init__ (line 363) | def __init__(
method preprocess (line 378) | def preprocess(self, x):
method forward (line 388) | def forward(self, x):
FILE: ldm/modules/encoders/modules_bak.py
function _expand_mask (line 11) | def _expand_mask(mask, dtype, tgt_len = None):
function _build_causal_attention_mask (line 24) | def _build_causal_attention_mask(bsz, seq_len, dtype):
class AbstractEncoder (line 33) | class AbstractEncoder(nn.Module):
method __init__ (line 34) | def __init__(self):
method encode (line 37) | def encode(self, *args, **kwargs):
class ClassEmbedder (line 42) | class ClassEmbedder(nn.Module):
method __init__ (line 43) | def __init__(self, embed_dim, n_classes=1000, key='class'):
method forward (line 48) | def forward(self, batch, key=None):
class TransformerEmbedder (line 57) | class TransformerEmbedder(AbstractEncoder):
method __init__ (line 59) | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, devic...
method forward (line 65) | def forward(self, tokens):
method encode (line 70) | def encode(self, x):
class BERTTokenizer (line 74) | class BERTTokenizer(AbstractEncoder):
method __init__ (line 76) | def __init__(self, device="cuda", vq_interface=True, max_length=77):
method forward (line 84) | def forward(self, text):
method encode (line 91) | def encode(self, text):
method decode (line 97) | def decode(self, text):
class BERTEmbedder (line 101) | class BERTEmbedder(AbstractEncoder):
method __init__ (line 103) | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
method forward (line 114) | def forward(self, text, embedding_manager=None):
method encode (line 122) | def encode(self, text, **kwargs):
class SpatialRescaler (line 126) | class SpatialRescaler(nn.Module):
method __init__ (line 127) | def __init__(self,
method forward (line 145) | def forward(self,x):
method encode (line 154) | def encode(self, x):
class FrozenCLIPEmbedder (line 157) | class FrozenCLIPEmbedder(AbstractEncoder):
method __init__ (line 159) | def __init__(self, version="openai/clip-vit-large-patch14", device="cu...
method freeze (line 410) | def freeze(self):
method forward (line 415) | def forward(self, text, **kwargs):
method encode (line 423) | def encode(self, text, **kwargs):
class FrozenCLIPTextEmbedder (line 427) | class FrozenCLIPTextEmbedder(nn.Module):
method __init__ (line 431) | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n...
method freeze (line 439) | def freeze(self):
method forward (line 444) | def forward(self, text):
method encode (line 451) | def encode(self, text):
class FrozenClipImageEmbedder (line 459) | class FrozenClipImageEmbedder(nn.Module):
method __init__ (line 463) | def __init__(
method preprocess (line 478) | def preprocess(self, x):
method forward (line 488) | def forward(self, x):
FILE: ldm/modules/image_degradation/bsrgan.py
function modcrop_np (line 29) | def modcrop_np(img, sf):
function analytic_kernel (line 49) | def analytic_kernel(k):
function anisotropic_Gaussian (line 65) | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
function gm_blur_kernel (line 86) | def gm_blur_kernel(mean, cov, size=15):
function shift_pixel (line 99) | def shift_pixel(x, sf, upper_left=True):
function blur (line 128) | def blur(x, k):
function gen_kernel (line 145) | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]),...
function fspecial_gaussian (line 187) | def fspecial_gaussian(hsize, sigma):
function fspecial_laplacian (line 201) | def fspecial_laplacian(alpha):
function fspecial (line 210) | def fspecial(filter_type, *args, **kwargs):
function bicubic_degradation (line 228) | def bicubic_degradation(x, sf=3):
function srmd_degradation (line 240) | def srmd_degradation(x, k, sf=3):
function dpsr_degradation (line 262) | def dpsr_degradation(x, k, sf=3):
function classical_degradation (line 284) | def classical_degradation(x, k, sf=3):
function add_sharpening (line 299) | def add_sharpening(img, weight=0.5, radius=50, threshold=10):
function add_blur (line 325) | def add_blur(img, sf=4):
function add_resize (line 339) | def add_resize(img, sf=4):
function add_Gaussian_noise (line 369) | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
function add_speckle_noise (line 386) | def add_speckle_noise(img, noise_level1=2, noise_level2=25):
function add_Poisson_noise (line 404) | def add_Poisson_noise(img):
function add_JPEG_noise (line 418) | def add_JPEG_noise(img):
function random_crop (line 427) | def random_crop(lq, hq, sf=4, lq_patchsize=64):
function degradation_bsrgan (line 438) | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
function degradation_bsrgan_variant (line 530) | def degradation_bsrgan_variant(image, sf=4, isp_model=None):
function degradation_bsrgan_plus (line 617) | def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True,...
FILE: ldm/modules/image_degradation/bsrgan_light.py
function modcrop_np (line 29) | def modcrop_np(img, sf):
function analytic_kernel (line 49) | def analytic_kernel(k):
function anisotropic_Gaussian (line 65) | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
function gm_blur_kernel (line 86) | def gm_blur_kernel(mean, cov, size=15):
function shift_pixel (line 99) | def shift_pixel(x, sf, upper_left=True):
function blur (line 128) | def blur(x, k):
function gen_kernel (line 145) | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]),...
function fspecial_gaussian (line 187) | def fspecial_gaussian(hsize, sigma):
function fspecial_laplacian (line 201) | def fspecial_laplacian(alpha):
function fspecial (line 210) | def fspecial(filter_type, *args, **kwargs):
function bicubic_degradation (line 228) | def bicubic_degradation(x, sf=3):
function srmd_degradation (line 240) | def srmd_degradation(x, k, sf=3):
function dpsr_degradation (line 262) | def dpsr_degradation(x, k, sf=3):
function classical_degradation (line 284) | def classical_degradation(x, k, sf=3):
function add_sharpening (line 299) | def add_sharpening(img, weight=0.5, radius=50, threshold=10):
function add_blur (line 325) | def add_blur(img, sf=4):
function add_resize (line 343) | def add_resize(img, sf=4):
function add_Gaussian_noise (line 373) | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
function add_speckle_noise (line 390) | def add_speckle_noise(img, noise_level1=2, noise_level2=25):
function add_Poisson_noise (line 408) | def add_Poisson_noise(img):
function add_JPEG_noise (line 422) | def add_JPEG_noise(img):
function random_crop (line 431) | def random_crop(lq, hq, sf=4, lq_patchsize=64):
function degradation_bsrgan (line 442) | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
function degradation_bsrgan_variant (line 534) | def degradation_bsrgan_variant(image, sf=4, isp_model=None):
FILE: ldm/modules/image_degradation/utils_image.py
function is_image_file (line 29) | def is_image_file(filename):
function get_timestamp (line 33) | def get_timestamp():
function imshow (line 37) | def imshow(x, title=None, cbar=False, figsize=None):
function surf (line 47) | def surf(Z, cmap='rainbow', figsize=None):
function get_image_paths (line 67) | def get_image_paths(dataroot):
function _get_paths_from_images (line 74) | def _get_paths_from_images(path):
function patches_from_image (line 93) | def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
function imssave (line 112) | def imssave(imgs, img_path):
function split_imageset (line 125) | def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_si...
function mkdir (line 153) | def mkdir(path):
function mkdirs (line 158) | def mkdirs(paths):
function mkdir_and_rename (line 166) | def mkdir_and_rename(path):
function imread_uint (line 185) | def imread_uint(path, n_channels=3):
function imsave (line 203) | def imsave(img, img_path):
function imwrite (line 209) | def imwrite(img, img_path):
function read_img (line 220) | def read_img(path):
function uint2single (line 249) | def uint2single(img):
function single2uint (line 254) | def single2uint(img):
function uint162single (line 259) | def uint162single(img):
function single2uint16 (line 264) | def single2uint16(img):
function uint2tensor4 (line 275) | def uint2tensor4(img):
function uint2tensor3 (line 282) | def uint2tensor3(img):
function tensor2uint (line 289) | def tensor2uint(img):
function single2tensor3 (line 302) | def single2tensor3(img):
function single2tensor4 (line 307) | def single2tensor4(img):
function tensor2single (line 312) | def tensor2single(img):
function tensor2single3 (line 320) | def tensor2single3(img):
function single2tensor5 (line 329) | def single2tensor5(img):
function single32tensor5 (line 333) | def single32tensor5(img):
function single42tensor4 (line 337) | def single42tensor4(img):
function tensor2img (line 342) | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
function augment_img (line 380) | def augment_img(img, mode=0):
function augment_img_tensor4 (line 401) | def augment_img_tensor4(img, mode=0):
function augment_img_tensor (line 422) | def augment_img_tensor(img, mode=0):
function augment_img_np3 (line 441) | def augment_img_np3(img, mode=0):
function augment_imgs (line 469) | def augment_imgs(img_list, hflip=True, rot=True):
function modcrop (line 494) | def modcrop(img_in, scale):
function shave (line 510) | def shave(img_in, border=0):
function rgb2ycbcr (line 529) | def rgb2ycbcr(img, only_y=True):
function ycbcr2rgb (line 553) | def ycbcr2rgb(img):
function bgr2ycbcr (line 573) | def bgr2ycbcr(img, only_y=True):
function channel_convert (line 597) | def channel_convert(in_c, tar_type, img_list):
function calculate_psnr (line 621) | def calculate_psnr(img1, img2, border=0):
function calculate_ssim (line 642) | def calculate_ssim(img1, img2, border=0):
function ssim (line 669) | def ssim(img1, img2):
function cubic (line 700) | def cubic(x):
function calculate_weights_indices (line 708) | def calculate_weights_indices(in_length, out_length, scale, kernel, kern...
function imresize (line 766) | def imresize(img, scale, antialiasing=True):
function imresize_np (line 839) | def imresize_np(img, scale, antialiasing=True):
FILE: ldm/modules/losses/contperceptual.py
class LPIPSWithDiscriminator (line 7) | class LPIPSWithDiscriminator(nn.Module):
method __init__ (line 8) | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixello...
method calculate_adaptive_weight (line 32) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
method forward (line 45) | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
FILE: ldm/modules/losses/vqperceptual.py
function hinge_d_loss_with_exemplar_weights (line 11) | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
function adopt_weight (line 20) | def adopt_weight(weight, global_step, threshold=0, value=0.):
function measure_perplexity (line 26) | def measure_perplexity(predicted_indices, n_embed):
function l1 (line 35) | def l1(x, y):
function l2 (line 39) | def l2(x, y):
class VQLPIPSWithDiscriminator (line 43) | class VQLPIPSWithDiscriminator(nn.Module):
method __init__ (line 44) | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
method calculate_adaptive_weight (line 85) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
method forward (line 98) | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
FILE: ldm/modules/x_transformer.py
class AbsolutePositionalEmbedding (line 25) | class AbsolutePositionalEmbedding(nn.Module):
method __init__ (line 26) | def __init__(self, dim, max_seq_len):
method init_ (line 31) | def init_(self):
method forward (line 34) | def forward(self, x):
class FixedPositionalEmbedding (line 39) | class FixedPositionalEmbedding(nn.Module):
method __init__ (line 40) | def __init__(self, dim):
method forward (line 45) | def forward(self, x, seq_dim=1, offset=0):
function exists (line 54) | def exists(val):
function default (line 58) | def default(val, d):
function always (line 64) | def always(val):
function not_equals (line 70) | def not_equals(val):
function equals (line 76) | def equals(val):
function max_neg_value (line 82) | def max_neg_value(tensor):
function pick_and_pop (line 88) | def pick_and_pop(keys, d):
function group_dict_by_key (line 93) | def group_dict_by_key(cond, d):
function string_begins_with (line 102) | def string_begins_with(prefix, str):
function group_by_key_prefix (line 106) | def group_by_key_prefix(prefix, d):
function groupby_prefix_and_trim (line 110) | def groupby_prefix_and_trim(prefix, d):
class Scale (line 117) | class Scale(nn.Module):
method __init__ (line 118) | def __init__(self, value, fn):
method forward (line 123) | def forward(self, x, **kwargs):
class Rezero (line 128) | class Rezero(nn.Module):
method __init__ (line 129) | def __init__(self, fn):
method forward (line 134) | def forward(self, x, **kwargs):
class ScaleNorm (line 139) | class ScaleNorm(nn.Module):
method __init__ (line 140) | def __init__(self, dim, eps=1e-5):
method forward (line 146) | def forward(self, x):
class RMSNorm (line 151) | class RMSNorm(nn.Module):
method __init__ (line 152) | def __init__(self, dim, eps=1e-8):
method forward (line 158) | def forward(self, x):
class Residual (line 163) | class Residual(nn.Module):
method forward (line 164) | def forward(self, x, residual):
class GRUGating (line 168) | class GRUGating(nn.Module):
method __init__ (line 169) | def __init__(self, dim):
method forward (line 173) | def forward(self, x, residual):
class GEGLU (line 184) | class GEGLU(nn.Module):
method __init__ (line 185) | def __init__(self, dim_in, dim_out):
method forward (line 189) | def forward(self, x):
class FeedForward (line 194) | class FeedForward(nn.Module):
method __init__ (line 195) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 210) | def forward(self, x):
class Attention (line 215) | class Attention(nn.Module):
method __init__ (line 216) | def __init__(
method forward (line 268) | def forward(
class AttentionLayers (line 370) | class AttentionLayers(nn.Module):
method __init__ (line 371) | def __init__(
method forward (line 481) | def forward(
class Encoder (line 542) | class Encoder(AttentionLayers):
method __init__ (line 543) | def __init__(self, **kwargs):
class TransformerWrapper (line 549) | class TransformerWrapper(nn.Module):
method __init__ (line 550) | def __init__(
method init_ (line 596) | def init_(self):
method forward (line 599) | def forward(
FILE: ldm/util.py
function log_txt_as_img (line 17) | def log_txt_as_img(wh, xc, size=10):
function ismap (line 41) | def ismap(x):
function isimage (line 47) | def isimage(x):
function exists (line 53) | def exists(x):
function default (line 57) | def default(val, d):
function mean_flat (line 63) | def mean_flat(tensor):
function count_params (line 71) | def count_params(model, verbose=False):
function instantiate_from_config (line 78) | def instantiate_from_config(config, **kwargs):
function get_obj_from_str (line 88) | def get_obj_from_str(string, reload=False):
function _do_parallel_data_prefetch (line 96) | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
function parallel_data_prefetch (line 108) | def parallel_data_prefetch(
FILE: main.py
function load_model_from_config (line 24) | def load_model_from_config(config, ckpt, verbose=False):
function get_parser (line 41) | def get_parser(**parser_kwargs):
function nondefault_trainer_args (line 181) | def nondefault_trainer_args(opt):
class WrappedDataset (line 188) | class WrappedDataset(Dataset):
method __init__ (line 191) | def __init__(self, dataset):
method __len__ (line 194) | def __len__(self):
method __getitem__ (line 197) | def __getitem__(self, idx):
function worker_init_fn (line 201) | def worker_init_fn(_):
class ConcatDataset (line 216) | class ConcatDataset(Dataset):
method __init__ (line 217) | def __init__(self, *datasets):
method __getitem__ (line 220) | def __getitem__(self, idx):
method __len__ (line 223) | def __len__(self):
class DataModuleFromConfig (line 226) | class DataModuleFromConfig(pl.LightningDataModule):
method __init__ (line 227) | def __init__(self, batch_size, train=None, reg = None, validation=None...
method prepare_data (line 253) | def prepare_data(self):
method setup (line 257) | def setup(self, stage=None):
method _train_dataloader (line 265) | def _train_dataloader(self):
method _val_dataloader (line 278) | def _val_dataloader(self, shuffle=False):
method _test_dataloader (line 289) | def _test_dataloader(self, shuffle=False):
method _predict_dataloader (line 302) | def _predict_dataloader(self, shuffle=False):
class SetupCallback (line 311) | class SetupCallback(Callback):
method __init__ (line 312) | def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, light...
method on_keyboard_interrupt (line 322) | def on_keyboard_interrupt(self, trainer, pl_module):
method on_pretrain_routine_start (line 328) | def on_pretrain_routine_start(self, trainer, pl_module):
class ImageLogger (line 360) | class ImageLogger(Callback):
method __init__ (line 361) | def __init__(self, batch_frequency, max_images, clamp=True, increase_l...
method _testtube (line 381) | def _testtube(self, pl_module, images, batch_idx, split):
method log_local (line 392) | def log_local(self, save_dir, split, images,
method log_img (line 411) | def log_img(self, pl_module, batch, batch_idx, split="train"):
method check_frequency (line 443) | def check_frequency(self, check_idx):
method on_train_batch_end (line 454) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
method on_validation_batch_end (line 458) | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, ...
class CUDACallback (line 466) | class CUDACallback(Callback):
method on_train_epoch_start (line 468) | def on_train_epoch_start(self, trainer, pl_module):
method on_train_epoch_end (line 474) | def on_train_epoch_end(self, trainer, pl_module):
class ModeSwapCallback (line 488) | class ModeSwapCallback(Callback):
method __init__ (line 490) | def __init__(self, swap_step=2000):
method on_train_epoch_start (line 495) | def on_train_epoch_start(self, trainer, pl_module):
function melk (line 808) | def melk(*args, **kwargs):
function divein (line 816) | def divein(*args, **kwargs):
FILE: merge_embeddings.py
function get_placeholder_loop (line 9) | def get_placeholder_loop(placeholder_string, embedder, is_sd):
function get_clip_token_for_string (line 24) | def get_clip_token_for_string(tokenizer, string):
function get_bert_token_for_string (line 34) | def get_bert_token_for_string(tokenizer, string):
FILE: scripts/evaluate_model.py
function load_model_from_config (line 19) | def load_model_from_config(config, ckpt, verbose=False):
FILE: scripts/inpaint.py
function make_batch (line 11) | def make_batch(image, mask, device):
FILE: scripts/sample_diffusion.py
function custom_to_pil (line 15) | def custom_to_pil(x):
function custom_to_np (line 27) | def custom_to_np(x):
function logs2pil (line 36) | def logs2pil(logs, keys=["sample"]):
function convsample (line 54) | def convsample(model, shape, return_intermediates=True,
function convsample_ddim (line 69) | def convsample_ddim(model, steps, shape, eta=1.0
function make_convolutional_sample (line 79) | def make_convolutional_sample(model, batch_size, vanilla=False, custom_s...
function run (line 108) | def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, ...
function save_logs (line 143) | def save_logs(logs, path, n_saved=0, key="sample", np_path=None):
function get_parser (line 162) | def get_parser():
function load_model_from_config (line 220) | def load_model_from_config(config, sd):
function load_model (line 228) | def load_model(config, ckpt, gpu, eval_mode):
FILE: scripts/stable_txt2img.py
function chunk (line 20) | def chunk(it, size):
function load_model_from_config (line 25) | def load_model_from_config(config, ckpt, verbose=False):
function main (line 45) | def main():
FILE: scripts/txt2img.py
function load_model_from_config (line 14) | def load_model_from_config(config, ckpt, verbose=False):
Condensed preview — 86 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (4,752K chars).
[
{
"path": "LICENSE",
"chars": 1125,
"preview": "MIT License\n\nCopyright (c) 2022 Rinon Gal, Yuval Alaluf, Yuval Atzmon, Or Patashnik and contributors\n\nPermission is here"
},
{
"path": "README.md",
"chars": 7197,
"preview": "# Dreambooth on Stable Diffusion\n\nThis is an implementtaion of Google's [Dreambooth](https://arxiv.org/abs/2208.12242) w"
},
{
"path": "configs/autoencoder/autoencoder_kl_16x16x16.yaml",
"chars": 1145,
"preview": "model:\n base_learning_rate: 4.5e-6\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: \"val/rec_loss\""
},
{
"path": "configs/autoencoder/autoencoder_kl_32x32x4.yaml",
"chars": 1140,
"preview": "model:\n base_learning_rate: 4.5e-6\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: \"val/rec_loss\""
},
{
"path": "configs/autoencoder/autoencoder_kl_64x64x3.yaml",
"chars": 1139,
"preview": "model:\n base_learning_rate: 4.5e-6\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: \"val/rec_loss\""
},
{
"path": "configs/autoencoder/autoencoder_kl_8x8x64.yaml",
"chars": 1148,
"preview": "model:\n base_learning_rate: 4.5e-6\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: \"val/rec_loss\""
},
{
"path": "configs/latent-diffusion/celebahq-ldm-vq-4.yaml",
"chars": 2028,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "configs/latent-diffusion/cin-ldm-vq-f8.yaml",
"chars": 2360,
"preview": "model:\n base_learning_rate: 1.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "configs/latent-diffusion/cin256-v2.yaml",
"chars": 1553,
"preview": "model:\n base_learning_rate: 0.0001\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.00"
},
{
"path": "configs/latent-diffusion/ffhq-ldm-vq-4.yaml",
"chars": 2020,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml",
"chars": 2024,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml",
"chars": 2284,
"preview": "model:\n base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'\n target: ldm.model"
},
{
"path": "configs/latent-diffusion/txt2img-1p4B-eval.yaml",
"chars": 1614,
"preview": "model:\n base_learning_rate: 5.0e-05\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml",
"chars": 1783,
"preview": "model:\n base_learning_rate: 5.0e-05\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "configs/latent-diffusion/txt2img-1p4B-finetune.yaml",
"chars": 2652,
"preview": "model:\n base_learning_rate: 5.0e-3\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.00"
},
{
"path": "configs/latent-diffusion/txt2img-1p4B-finetune_style.yaml",
"chars": 2610,
"preview": "model:\n base_learning_rate: 5.0e-3\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.00"
},
{
"path": "configs/stable-diffusion/v1-finetune.yaml",
"chars": 2649,
"preview": "model:\n base_learning_rate: 5.0e-03\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "configs/stable-diffusion/v1-finetune_unfrozen.yaml",
"chars": 2863,
"preview": "model:\n base_learning_rate: 1.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n reg_weight: 1.0\n "
},
{
"path": "configs/stable-diffusion/v1-inference.yaml",
"chars": 1851,
"preview": "model:\n base_learning_rate: 1.0e-04\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "environment.yaml",
"chars": 739,
"preview": "name: ldm\nchannels:\n - pytorch\n - defaults\ndependencies:\n - python=3.8.10\n - pip=20.3\n - cudatoolkit=11.3\n - pytor"
},
{
"path": "evaluation/clip_eval.py",
"chars": 4697,
"preview": "import clip\nimport torch\nfrom torchvision import transforms\n\nfrom ldm.models.diffusion.ddim import DDIMSampler\n\nclass CL"
},
{
"path": "ldm/data/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ldm/data/base.py",
"chars": 693,
"preview": "from abc import abstractmethod\nfrom torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset\n\n\nclas"
},
{
"path": "ldm/data/imagenet.py",
"chars": 15497,
"preview": "import os, yaml, pickle, shutil, tarfile, glob\nimport cv2\nimport albumentations\nimport PIL\nimport numpy as np\nimport tor"
},
{
"path": "ldm/data/lsun.py",
"chars": 3274,
"preview": "import os\nimport numpy as np\nimport PIL\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision impo"
},
{
"path": "ldm/data/personalized.py",
"chars": 7081,
"preview": "import os\nimport numpy as np\nimport PIL\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision impo"
},
{
"path": "ldm/data/personalized_style.py",
"chars": 4797,
"preview": "import os\nimport numpy as np\nimport PIL\nfrom PIL import Image\nfrom torch.utils.data import Dataset\nfrom torchvision impo"
},
{
"path": "ldm/lr_scheduler.py",
"chars": 3882,
"preview": "import numpy as np\n\n\nclass LambdaWarmUpCosineScheduler:\n \"\"\"\n note: use with a base_lr of 1.0\n \"\"\"\n def __in"
},
{
"path": "ldm/models/autoencoder.py",
"chars": 17619,
"preview": "import torch\nimport pytorch_lightning as pl\nimport torch.nn.functional as F\nfrom contextlib import contextmanager\n\nfrom "
},
{
"path": "ldm/models/diffusion/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ldm/models/diffusion/classifier.py",
"chars": 10276,
"preview": "import os\nimport torch\nimport pytorch_lightning as pl\nfrom omegaconf import OmegaConf\nfrom torch.nn import functional as"
},
{
"path": "ldm/models/diffusion/ddim.py",
"chars": 12797,
"preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom functools import partial\n\nfrom ldm.modu"
},
{
"path": "ldm/models/diffusion/ddpm.py",
"chars": 72511,
"preview": "\"\"\"\nwild mixture of\nhttps://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e316"
},
{
"path": "ldm/models/diffusion/plms.py",
"chars": 12450,
"preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom functools import partial\n\nfrom ldm.modu"
},
{
"path": "ldm/modules/attention.py",
"chars": 8531,
"preview": "from inspect import isfunction\nimport math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum\nfro"
},
{
"path": "ldm/modules/diffusionmodules/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ldm/modules/diffusionmodules/model.py",
"chars": 33409,
"preview": "# pytorch_diffusion + derived encoder decoder\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom ein"
},
{
"path": "ldm/modules/diffusionmodules/openaimodel.py",
"chars": 34953,
"preview": "from abc import abstractmethod\nfrom functools import partial\nimport math\nfrom typing import Iterable\n\nimport numpy as np"
},
{
"path": "ldm/modules/diffusionmodules/util.py",
"chars": 9633,
"preview": "# adopted from\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\n# and\n#"
},
{
"path": "ldm/modules/distributions/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ldm/modules/distributions/distributions.py",
"chars": 2970,
"preview": "import torch\nimport numpy as np\n\n\nclass AbstractDistribution:\n def sample(self):\n raise NotImplementedError()\n"
},
{
"path": "ldm/modules/ema.py",
"chars": 2982,
"preview": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n def __init__(self, model, decay=0.9999, use_num_upates="
},
{
"path": "ldm/modules/embedding_manager.py",
"chars": 6630,
"preview": "import torch\nfrom torch import nn\n\nfrom ldm.data.personalized import per_img_token_list\nfrom transformers import CLIPTok"
},
{
"path": "ldm/modules/encoders/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ldm/modules/encoders/modules.py",
"chars": 14801,
"preview": "import torch\nimport torch.nn as nn\nfrom functools import partial\nimport clip\nfrom einops import rearrange, repeat\nfrom t"
},
{
"path": "ldm/modules/encoders/modules_bak.py",
"chars": 18883,
"preview": "import torch\nimport torch.nn as nn\nfrom functools import partial\nimport clip\nfrom einops import rearrange, repeat\nfrom t"
},
{
"path": "ldm/modules/image_degradation/__init__.py",
"chars": 208,
"preview": "from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr\nfrom ldm.modules.image"
},
{
"path": "ldm/modules/image_degradation/bsrgan.py",
"chars": 25198,
"preview": "# -*- coding: utf-8 -*-\n\"\"\"\n# --------------------------------------------\n# Super-Resolution\n# ------------------------"
},
{
"path": "ldm/modules/image_degradation/bsrgan_light.py",
"chars": 22238,
"preview": "# -*- coding: utf-8 -*-\nimport numpy as np\nimport cv2\nimport torch\n\nfrom functools import partial\nimport random\nfrom sci"
},
{
"path": "ldm/modules/image_degradation/utils_image.py",
"chars": 29022,
"preview": "import os\nimport math\nimport random\nimport numpy as np\nimport torch\nimport cv2\nfrom torchvision.utils import make_grid\nf"
},
{
"path": "ldm/modules/losses/__init__.py",
"chars": 68,
"preview": "from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator"
},
{
"path": "ldm/modules/losses/contperceptual.py",
"chars": 5581,
"preview": "import torch\nimport torch.nn as nn\n\nfrom taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?\n"
},
{
"path": "ldm/modules/losses/vqperceptual.py",
"chars": 7941,
"preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom einops import repeat\n\nfrom taming.modules.discrim"
},
{
"path": "ldm/modules/x_transformer.py",
"chars": 20369,
"preview": "\"\"\"shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers\"\"\"\nimport torch\nfrom torch import "
},
{
"path": "ldm/util.py",
"chars": 5877,
"preview": "import importlib\n\nimport torch\nimport numpy as np\nfrom collections import abc\nfrom einops import rearrange\nfrom functool"
},
{
"path": "main.py",
"chars": 32107,
"preview": "import argparse, os, sys, datetime, glob, importlib, csv\nimport numpy as np\nimport time\nimport torch\n\nimport torchvision"
},
{
"path": "merge_embeddings.py",
"chars": 3706,
"preview": "from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder\nfrom ldm.modules.embedding_manager import Embe"
},
{
"path": "models/first_stage_models/kl-f16/config.yaml",
"chars": 909,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: val/rec_loss\n"
},
{
"path": "models/first_stage_models/kl-f32/config.yaml",
"chars": 929,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: val/rec_loss\n"
},
{
"path": "models/first_stage_models/kl-f4/config.yaml",
"chars": 880,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: val/rec_loss\n"
},
{
"path": "models/first_stage_models/kl-f8/config.yaml",
"chars": 889,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.AutoencoderKL\n params:\n monitor: val/rec_loss\n"
},
{
"path": "models/first_stage_models/vq-f16/config.yaml",
"chars": 1026,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.VQModel\n params:\n embed_dim: 8\n n_embed: 16"
},
{
"path": "models/first_stage_models/vq-f4/config.yaml",
"chars": 955,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.VQModel\n params:\n embed_dim: 3\n n_embed: 81"
},
{
"path": "models/first_stage_models/vq-f4-noattn/config.yaml",
"chars": 978,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.VQModel\n params:\n embed_dim: 3\n n_embed: 81"
},
{
"path": "models/first_stage_models/vq-f8/config.yaml",
"chars": 1035,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.VQModel\n params:\n embed_dim: 4\n n_embed: 16"
},
{
"path": "models/first_stage_models/vq-f8-n256/config.yaml",
"chars": 1013,
"preview": "model:\n base_learning_rate: 4.5e-06\n target: ldm.models.autoencoder.VQModel\n params:\n embed_dim: 4\n n_embed: 25"
},
{
"path": "models/ldm/bsr_sr/config.yaml",
"chars": 1900,
"preview": "model:\n base_learning_rate: 1.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/ldm/celeba256/config.yaml",
"chars": 1599,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/ldm/cin256/config.yaml",
"chars": 1862,
"preview": "model:\n base_learning_rate: 1.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/ldm/ffhq256/config.yaml",
"chars": 1591,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/ldm/inpainting_big/config.yaml",
"chars": 1619,
"preview": "model:\n base_learning_rate: 1.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/ldm/layout2img-openimages256/config.yaml",
"chars": 1924,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/ldm/lsun_beds256/config.yaml",
"chars": 1601,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/ldm/lsun_churches256/config.yaml",
"chars": 2018,
"preview": "model:\n base_learning_rate: 5.0e-05\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/ldm/semantic_synthesis256/config.yaml",
"chars": 1378,
"preview": "model:\n base_learning_rate: 1.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/ldm/semantic_synthesis512/config.yaml",
"chars": 1820,
"preview": "model:\n base_learning_rate: 1.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "models/ldm/text2img256/config.yaml",
"chars": 1831,
"preview": "model:\n base_learning_rate: 2.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusion\n params:\n linear_start: 0.0"
},
{
"path": "scripts/download_first_stages.sh",
"chars": 1324,
"preview": "#!/bin/bash\nwget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip\nwge"
},
{
"path": "scripts/download_models.sh",
"chars": 1681,
"preview": "#!/bin/bash\nwget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip\nwget -O "
},
{
"path": "scripts/evaluate_model.py",
"chars": 2720,
"preview": "import argparse, os, sys, glob\n\nsys.path.append(os.path.join(sys.path[0], '..'))\n\nimport torch\nimport numpy as np\nfrom o"
},
{
"path": "scripts/inpaint.py",
"chars": 3644,
"preview": "import argparse, os, sys, glob\nfrom omegaconf import OmegaConf\nfrom PIL import Image\nfrom tqdm import tqdm\nimport numpy "
},
{
"path": "scripts/latent_imagenet_diffusion.ipynb",
"chars": 4172302,
"preview": "{\n \"nbformat\": 4,\n \"nbformat_minor\": 0,\n \"metadata\": {\n \"colab\": {\n \"name\": \"latent-imagenet-diffusion.ipynb\",\n \"pr"
},
{
"path": "scripts/sample_diffusion.py",
"chars": 9606,
"preview": "import argparse, os, sys, glob, datetime, yaml\nimport torch\nimport time\nimport numpy as np\nfrom tqdm import trange\n\nfrom"
},
{
"path": "scripts/stable_txt2img.py",
"chars": 9425,
"preview": "import argparse, os, sys, glob\nimport torch\nimport numpy as np\nfrom omegaconf import OmegaConf\nfrom PIL import Image\nfro"
},
{
"path": "scripts/txt2img.py",
"chars": 5660,
"preview": "import argparse, os, sys, glob\nimport torch\nimport numpy as np\nfrom omegaconf import OmegaConf\nfrom PIL import Image\nfro"
},
{
"path": "setup.py",
"chars": 233,
"preview": "from setuptools import setup, find_packages\n\nsetup(\n name='latent-diffusion',\n version='0.0.1',\n description=''"
}
]
About this extraction
This page contains the full source code of the XavierXiao/Dreambooth-Stable-Diffusion GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 86 files (4.5 MB), approximately 1.2M tokens, and a symbol index with 724 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.