Repository: danier97/LDMVFI
Branch: main
Commit: eee2dc3566f2
Files: 47
Total size: 425.2 KB
Directory structure:
gitextract_grs9wzig/
├── .gitignore
├── LICENSE
├── README.md
├── configs/
│   ├── autoencoder/
│   │   └── vqflow-f32.yaml
│   └── ldm/
│       └── ldmvfi-vqflow-f32-c256-concat_max.yaml
├── cupy_module/
│   ├── __init__.py
│   └── dsepconv.py
├── environment.yaml
├── evaluate.py
├── evaluate_vqm.py
├── interpolate_yuv.py
├── ldm/
│   ├── data/
│   │   ├── __init__.py
│   │   ├── bvi_vimeo.py
│   │   ├── testsets.py
│   │   ├── testsets_vqm.py
│   │   └── vfitransforms.py
│   ├── lr_scheduler.py
│   ├── models/
│   │   ├── autoencoder.py
│   │   └── diffusion/
│   │       ├── __init__.py
│   │       ├── ddim.py
│   │       └── ddpm.py
│   ├── modules/
│   │   ├── attention.py
│   │   ├── diffusionmodules/
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   └── util.py
│   │   ├── ema.py
│   │   ├── losses/
│   │   │   ├── __init__.py
│   │   │   └── vqperceptual.py
│   │   └── maxvit.py
│   └── util.py
├── main.py
├── metrics/
│   ├── flolpips/
│   │   ├── .gitignore
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── correlation/
│   │   │   └── correlation.py
│   │   ├── flolpips.py
│   │   ├── pretrained_networks.py
│   │   ├── pwcnet.py
│   │   └── utils.py
│   ├── lpips/
│   │   ├── __init__.py
│   │   ├── lpips.py
│   │   └── pretrained_networks.py
│   └── pytorch_ssim/
│       └── __init__.py
├── setup.py
└── utility.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pth
*.ckpt
*__pycache__*
*.pyc
*egg*
*src/*
*.ipynb
logs/*
*delete*
eval_results*
*.idea*
*.pytorch
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 danielism97
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# LDMVFI: Video Frame Interpolation with Latent Diffusion Models
[**Duolikun Danier**](https://danier97.github.io/), [**Fan Zhang**](https://fan-aaron-zhang.github.io/), [**David Bull**](https://david-bull.github.io/)
[Project](TODO) | [arXiv](https://arxiv.org/abs/2303.09508) | [Video](https://drive.google.com/file/d/1oL6j_l3b2QEqsL0iO7qSZrGUXJaTpRWN/view?usp=share_link)

## Overview
We observe that most existing learning-based VFI models are trained to minimise the L1/L2/VGG loss between their outputs and the ground-truth frames. However, it was shown in previous works that these metrics do not correlate well with the **perceptual quality** of VFI. On the other hand, generative models, especially diffusion models, are showing remarkable results in generating visual content with high perceptual quality. In this work, we leverage the high-fidelity image/video generation capabilities of **latent diffusion models** to perform generative VFI.
<p align="center">
<img src="https://danier97.github.io/LDMVFI/overall.svg" alt="Paper" width="60%">
</p>
## Dependencies and Installation
See [environment.yaml](./environment.yaml) for the required packages. To create the conda environment:
```
conda env create -f environment.yaml
```
## Pre-trained Model
The pre-trained model can be downloaded from [here](https://drive.google.com/file/d/1_Xx2fBYQT9O-6O3zjzX76O9XduGnCh_7/view?usp=share_link), and its corresponding config file is [this yaml](./configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml).
## Preparing datasets
### Training sets:
[[Vimeo-90K]](http://toflow.csail.mit.edu/) | [[BVI-DVC quintuplets]](https://drive.google.com/file/d/1i_CoqiNrZ2AU8DKjU8aHM1jIaDGW0fE5/view?usp=sharing)
### Test sets:
[[Middlebury]](https://vision.middlebury.edu/flow/data/) | [[UCF101]](https://sites.google.com/view/xiangyuxu/qvi_nips19) | [[DAVIS]](https://sites.google.com/view/xiangyuxu/qvi_nips19) | [[SNU-FILM]](https://myungsub.github.io/CAIN/)
To use [evaluate.py](evaluate.py) and the dataset classes in [ldm/data/](./ldm/data/), the dataset folder names should be lower-case and structured as follows.
```
└──── <data directory>/
    ├──── middlebury_others/
    |   ├──── input/
    |   |   ├──── Beanbags/
    |   |   ├──── ...
    |   |   └──── Walking/
    |   └──── gt/
    |       ├──── Beanbags/
    |       ├──── ...
    |       └──── Walking/
    ├──── ucf101/
    |   ├──── 0/
    |   ├──── ...
    |   └──── 99/
    ├──── davis90/
    |   ├──── bear/
    |   ├──── ...
    |   └──── walking/
    ├──── snufilm/
    |   ├──── test-easy.txt
    |   ├──── ...
    |   └──── data/SNU-FILM/test/...
    ├──── bvidvc/quintuplets
    |   ├──── 00000/
    |   ├──── ...
    |   └──── 17599/
    └──── vimeo_septuplet/
        ├──── sequences/
        ├──── sep_testlist.txt
        └──── sep_trainlist.txt
```
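Before launching an evaluation, it can help to confirm that the folders in the tree above are actually present. The following is a minimal sketch and not part of the repository: `EXPECTED_SUBDIRS` and `missing_dataset_dirs` are hypothetical names, and the list only mirrors the top-level folders shown above.

```python
from pathlib import Path

# Folders copied from the tree above (hypothetical helper, not repository code).
EXPECTED_SUBDIRS = [
    "middlebury_others/input",
    "middlebury_others/gt",
    "ucf101",
    "davis90",
    "snufilm",
    "bvidvc/quintuplets",
    "vimeo_septuplet/sequences",
]


def missing_dataset_dirs(data_dir):
    """Return the expected sub-directories that do not exist under data_dir."""
    root = Path(data_dir)
    return [sub for sub in EXPECTED_SUBDIRS if not (root / sub).is_dir()]
```

Calling `missing_dataset_dirs("<path/to/data/dir>")` returns an empty list when every listed folder exists. Only the test sets you intend to evaluate need to be present, so a non-empty result is not necessarily an error.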
## Evaluation
To evaluate LDMVFI (with DDIM sampler), for example, on the Middlebury dataset, using PSNR/SSIM/LPIPS, run the following command.
```
python evaluate.py \
--config configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml \
--ckpt <path/to/ldmvfi-vqflow-f32-c256-concat_max.ckpt> \
--dataset Middlebury_others \
--metrics PSNR SSIM LPIPS \
--data_dir <path/to/data/dir> \
--out_dir eval_results/ldmvfi-vqflow-f32-c256-concat_max/ \
--use_ddim
```
This will create the directory `eval_results/ldmvfi-vqflow-f32-c256-concat_max/Middlebury_others/` and store the interpolated frames, as well as a `results.txt` file, in that directory. For other test sets, replace `Middlebury_others` with the corresponding class name defined in [ldm/data/testsets.py](ldm/data/testsets.py) (e.g. `Ucf101_triplet`).
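
For reference, PSNR is derived from the mean squared error between an interpolated frame and its ground truth. The sketch below shows only the standard definition on flat sequences of pixel values, with an assumed peak value of 255. It is not the implementation used by `evaluate.py`, which may operate on a different value range.

```python
import math


def psnr(pred, target, peak=255.0):
    """Standard PSNR in dB between two equally sized sequences of pixel values."""
    if len(pred) != len(target) or len(pred) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    # Mean squared error over all pixel values.
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return math.inf  # identical inputs
    return 10.0 * math.log10(peak ** 2 / mse)
```

Higher values mean a smaller pixel-wise error. As noted in the Overview, such pixel-wise metrics do not necessarily track perceptual quality, which is why LPIPS and FloLPIPS are reported as well.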
\
To evaluate the model on the perceptual video metric FloLPIPS, first evaluate the image metrics using the command above (so that the interpolated frames are saved in `eval_results/ldmvfi-vqflow-f32-c256-concat_max`), then run the following command.
```
python evaluate_vqm.py \
--exp ldmvfi-vqflow-f32-c256-concat_max \
--dataset Middlebury_others \
--metrics FloLPIPS \
--data_dir <path/to/data/dir> \
--out_dir eval_results/ldmvfi-vqflow-f32-c256-concat_max/
```
This will read the interpolated frames previously stored in `eval_results/ldmvfi-vqflow-f32-c256-concat_max/Middlebury_others/` and write the evaluation results to `results_vqm.txt` in the same folder.
\
To interpolate a video (in .yuv format), use the following code.
```
python interpolate_yuv.py \
--net LDMVFI \
--config configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml \
--ckpt <path/to/ldmvfi-vqflow-f32-c256-concat_max.ckpt> \
--input_yuv <path/to/input/yuv> \
--size <spatial res of video, e.g. 1920x1080> \
--out_fps <output fps, should be 2 x original fps> \
--out_dir <desired/output/dir> \
--use_ddim
```
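A raw `.yuv` file has no header, which is why `--size` must be given explicitly, and inserting one frame between each input pair is what doubles the frame rate passed as `--out_fps`. The sketch below works through that arithmetic. It assumes 8-bit YUV 4:2:0 planar data and one new frame per consecutive input pair; both are assumptions, since the README states neither, and all function names are hypothetical.

```python
def yuv420_frame_bytes(width, height):
    """Bytes per 8-bit YUV 4:2:0 frame: a full Y plane plus two subsampled chroma planes."""
    chroma_w, chroma_h = (width + 1) // 2, (height + 1) // 2
    return width * height + 2 * chroma_w * chroma_h


def parse_size(size):
    """Split a 'WIDTHxHEIGHT' string such as '1920x1080' into two integers."""
    width, height = size.lower().split("x")
    return int(width), int(height)


def interpolation_plan(file_bytes, size, in_fps):
    """Estimate frame counts and output frame rate for 2x interpolation."""
    width, height = parse_size(size)
    n_in = file_bytes // yuv420_frame_bytes(width, height)
    # Assumption: one interpolated frame between each consecutive input pair.
    n_out = 2 * n_in - 1 if n_in > 0 else 0
    return {"input_frames": n_in, "output_frames": n_out, "out_fps": 2 * in_fps}
```

For a 1920x1080 input at 30 fps, each frame occupies 3,110,400 bytes under these assumptions, and `--out_fps` would be 60.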
## Training
LDMVFI is trained in two stages, where the VQ-FIGAN and the denoising U-Net are trained separately.
### VQ-FIGAN
```
python main.py --base configs/autoencoder/vqflow-f32.yaml -t --gpus 0,
```
### Denoising U-Net
```
python main.py --base configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml -t --gpus 0,
```
These commands will create a `logs/` folder containing one directory per experiment. The training logs include checkpoints, images and TensorBoard logs.
To resume from a checkpoint file, simply use the `--resume` argument in [main.py](main.py) to specify the checkpoint.
## Citation
```
@article{danier2023ldmvfi,
title={LDMVFI: Video Frame Interpolation with Latent Diffusion Models},
author={Danier, Duolikun and Zhang, Fan and Bull, David},
journal={arXiv preprint arXiv:2303.09508},
year={2023}
}
```
## Acknowledgement
Our code is adapted from the original [latent-diffusion](https://github.com/CompVis/latent-diffusion) repository. We thank the authors for sharing their code.
================================================
FILE: configs/autoencoder/vqflow-f32.yaml
================================================
model:
  base_learning_rate: 1.0e-5
  target: ldm.models.autoencoder.VQFlowNet
  params:
    monitor: val/total_loss
    embed_dim: 3
    n_embed: 8192
    ddconfig:
      double_z: False
      z_channels: 3
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 64
      ch_mult: [1,2,2,2,4] # f = 2 ^ len(ch_mult)
      num_res_blocks: 1
      cond_type: max_cross_attn
      attn_type: max
      attn_resolutions: []
      dropout: 0.0
    lossconfig:
      target: ldm.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: False
        disc_in_channels: 3
        disc_start: 10000
        disc_weight: 0.8
        codebook_weight: 1.0
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 10
    num_workers: 0
    wrap: false
    train:
      target: ldm.data.bvi_vimeo.BVI_Vimeo_triplet
      params:
        db_dir: C:/data_tmp/
        crop_sz: [256,256]
        iter: True
    validation:
      target: ldm.data.bvi_vimeo.Vimeo90k_triplet
      params:
        db_dir: C:/data_tmp/vimeo_septuplet/
        train: False
        crop_sz: [256,256]
        augment_s: False
        augment_t: False
lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 8000
        val_batch_frequency: 800
        max_images: 8
        increase_log_steps: False
        log_images_kwargs: {'N': 1}
  trainer:
    benchmark: True
    max_epochs: -1
================================================
FILE: configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml
================================================
model:
  base_learning_rate: 1.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusionVFI
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: past_future_frames
    image_size: 8
    channels: 3
    cond_stage_trainable: False
    concat_mode: True
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 8 # img size of latent, used during training, determines some model params, so don't change for inference
        in_channels: 9
        out_channels: 3
        model_channels: 256
        attention_resolutions:
        # note: this isn't actually the resolution but
        # the downsampling factor, i.e. this corresponds to
        # attention on spatial resolution 8,16,32, as the
        # spatial resolution of the latents is 32 for f8
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        num_head_channels: 32
        use_max_self_attn: True # replace all full self-attention with MaxViT
    first_stage_config:
      target: ldm.models.autoencoder.VQFlowNetInterface
      params:
        ckpt_path: null # must specify pre-trained autoencoding model ckpt to train the denoising UNet
        embed_dim: 3
        n_embed: 8192
        ddconfig:
          double_z: False
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 64
          ch_mult: [1,2,2,2,4] # f = 2 ^ len(ch_mult)
          num_res_blocks: 1
          cond_type: max_cross_attn
          attn_type: max
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_first_stage__
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 64
    num_workers: 0
    wrap: false
    train:
      target: ldm.data.bvi_vimeo.BVI_Vimeo_triplet
      params:
        db_dir: C:/data_tmp/
        crop_sz: [256,256]
        iter: True
    validation:
      target: ldm.data.bvi_vimeo.Vimeo90k_triplet
      params:
        db_dir: C:/data_tmp/vimeo_septuplet/
        train: False
        crop_sz: [256,256]
        augment_s: False
        augment_t: False
lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1250
        val_batch_frequency: 125
        max_images: 8
        increase_log_steps: False
        log_images_kwargs: {'N': 1}
  trainer:
    benchmark: True
    max_epochs: -1
================================================
FILE: cupy_module/__init__.py
================================================
================================================
FILE: cupy_module/dsepconv.py
================================================
import torch
import cupy
import re
class Stream:
    ptr = torch.cuda.current_stream().cuda_stream
# end
kernel_DSepconv_updateOutput = '''
extern "C" __global__ void kernel_DSepconv_updateOutput(
const int n,
const float* input,
const float* vertical,
const float* horizontal,
const float* offset_x,
const float* offset_y,
const float* mask,
float* output
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
float dblOutput = 0.0;
const int intSample = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output);
const int intDepth = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output);
const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output);
const int intX = ( intIndex ) % SIZE_3(output);
for (int intFilterY = 0; intFilterY < SIZE_1(vertical); intFilterY += 1) {
for (int intFilterX = 0; intFilterX < SIZE_1(horizontal); intFilterX += 1) {
float delta_x = OFFSET_4(offset_y, intSample, intFilterY*SIZE_1(vertical) + intFilterX, intY, intX);
float delta_y = OFFSET_4(offset_x, intSample, intFilterY*SIZE_1(vertical) + intFilterX, intY, intX);
float position_x = delta_x + intX + intFilterX - (SIZE_1(horizontal) - 1) / 2 + 1;
float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1;
if (position_x < 0)
position_x = 0;
if (position_x > SIZE_3(input) - 1)
position_x = SIZE_3(input) - 1;
if (position_y < 0)
position_y = 0;
if (position_y > SIZE_2(input) - 1)
position_y = SIZE_2(input) - 1;
int left = floor(delta_x + intX + intFilterX - (SIZE_1(horizontal) - 1) / 2 + 1);
int right = left + 1;
if (left < 0)
left = 0;
if (left > SIZE_3(input) - 1)
left = SIZE_3(input) - 1;
if (right < 0)
right = 0;
if (right > SIZE_3(input) - 1)
right = SIZE_3(input) - 1;
int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1);
int bottom = top + 1;
if (top < 0)
top = 0;
if (top > SIZE_2(input) - 1)
top = SIZE_2(input) - 1;
if (bottom < 0)
bottom = 0;
if (bottom > SIZE_2(input) - 1)
bottom = SIZE_2(input) - 1;
float floatValue = VALUE_4(input, intSample, intDepth, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, intDepth, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, intDepth, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) +
VALUE_4(input, intSample, intDepth, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y));
dblOutput += floatValue * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(horizontal, intSample, intFilterX, intY, intX) * VALUE_4(mask, intSample, SIZE_1(vertical)*intFilterY + intFilterX, intY, intX);
}
}
output[intIndex] = dblOutput;
} }
'''
kernel_DSepconv_updateGradVertical = '''
extern "C" __global__ void kernel_DSepconv_updateGradVertical(
const int n,
const float* gradLoss,
const float* input,
const float* horizontal,
const float* offset_x,
const float* offset_y,
const float* mask,
float* gradVertical
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
float floatOutput = 0.0;
const int intSample = ( intIndex / SIZE_3(gradVertical) / SIZE_2(gradVertical) / SIZE_1(gradVertical) ) % SIZE_0(gradVertical);
const int intFilterY = ( intIndex / SIZE_3(gradVertical) / SIZE_2(gradVertical) ) % SIZE_1(gradVertical);
const int intY = ( intIndex / SIZE_3(gradVertical) ) % SIZE_2(gradVertical);
const int intX = ( intIndex ) % SIZE_3(gradVertical);
for (int intFilterX = 0; intFilterX < SIZE_1(horizontal); intFilterX += 1){
int intDepth = intFilterY * SIZE_1(horizontal) + intFilterX;
float delta_x = OFFSET_4(offset_y, intSample, intDepth, intY, intX);
float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX);
float position_x = delta_x + intX + intFilterX - (SIZE_1(horizontal) - 1) / 2 + 1;
float position_y = delta_y + intY + intFilterY - (SIZE_1(horizontal) - 1) / 2 + 1;
if (position_x < 0)
position_x = 0;
if (position_x > SIZE_3(input) - 1)
position_x = SIZE_3(input) - 1;
if (position_y < 0)
position_y = 0;
if (position_y > SIZE_2(input) - 1)
position_y = SIZE_2(input) - 1;
int left = floor(delta_x + intX + intFilterX - (SIZE_1(horizontal) - 1) / 2 + 1);
int right = left + 1;
if (left < 0)
left = 0;
if (left > SIZE_3(input) - 1)
left = SIZE_3(input) - 1;
if (right < 0)
right = 0;
if (right > SIZE_3(input) - 1)
right = SIZE_3(input) - 1;
int top = floor(delta_y + intY + intFilterY - (SIZE_1(horizontal) - 1) / 2 + 1);
int bottom = top + 1;
if (top < 0)
top = 0;
if (top > SIZE_2(input) - 1)
top = SIZE_2(input) - 1;
if (bottom < 0)
bottom = 0;
if (bottom > SIZE_2(input) - 1)
bottom = SIZE_2(input) - 1;
float floatSampled0 = VALUE_4(input, intSample, 0, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, 0, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, 0, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) +
VALUE_4(input, intSample, 0, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y));
float floatSampled1 = VALUE_4(input, intSample, 1, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, 1, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, 1, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) +
VALUE_4(input, intSample, 1, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y));
float floatSampled2 = VALUE_4(input, intSample, 2, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, 2, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, 2, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) +
VALUE_4(input, intSample, 2, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y));
floatOutput += VALUE_4(gradLoss, intSample, 0, intY, intX) * floatSampled0 * VALUE_4(horizontal, intSample, intFilterX, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX) +
VALUE_4(gradLoss, intSample, 1, intY, intX) * floatSampled1 * VALUE_4(horizontal, intSample, intFilterX, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX) +
VALUE_4(gradLoss, intSample, 2, intY, intX) * floatSampled2 * VALUE_4(horizontal, intSample, intFilterX, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX);
}
gradVertical[intIndex] = floatOutput;
} }
'''
kernel_DSepconv_updateGradHorizontal = '''
extern "C" __global__ void kernel_DSepconv_updateGradHorizontal(
const int n,
const float* gradLoss,
const float* input,
const float* vertical,
const float* offset_x,
const float* offset_y,
const float* mask,
float* gradHorizontal
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
float floatOutput = 0.0;
const int intSample = ( intIndex / SIZE_3(gradHorizontal) / SIZE_2(gradHorizontal) / SIZE_1(gradHorizontal) ) % SIZE_0(gradHorizontal);
const int intFilterX = ( intIndex / SIZE_3(gradHorizontal) / SIZE_2(gradHorizontal) ) % SIZE_1(gradHorizontal);
const int intY = ( intIndex / SIZE_3(gradHorizontal) ) % SIZE_2(gradHorizontal);
const int intX = ( intIndex ) % SIZE_3(gradHorizontal);
for (int intFilterY = 0; intFilterY < SIZE_1(vertical); intFilterY += 1){
int intDepth = intFilterY * SIZE_1(vertical) + intFilterX;
float delta_x = OFFSET_4(offset_y, intSample, intDepth, intY, intX);
float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX);
float position_x = delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1;
float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1;
if (position_x < 0)
position_x = 0;
if (position_x > SIZE_3(input) - 1)
position_x = SIZE_3(input) - 1;
if (position_y < 0)
position_y = 0;
if (position_y > SIZE_2(input) - 1)
position_y = SIZE_2(input) - 1;
int left = floor(delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1);
int right = left + 1;
if (left < 0)
left = 0;
if (left > SIZE_3(input) - 1)
left = SIZE_3(input) - 1;
if (right < 0)
right = 0;
if (right > SIZE_3(input) - 1)
right = SIZE_3(input) - 1;
int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1);
int bottom = top + 1;
if (top < 0)
top = 0;
if (top > SIZE_2(input) - 1)
top = SIZE_2(input) - 1;
if (bottom < 0)
bottom = 0;
if (bottom > SIZE_2(input) - 1)
bottom = SIZE_2(input) - 1;
float floatSampled0 = VALUE_4(input, intSample, 0, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, 0, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, 0, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) +
VALUE_4(input, intSample, 0, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y));
float floatSampled1 = VALUE_4(input, intSample, 1, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, 1, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, 1, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) +
VALUE_4(input, intSample, 1, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y));
float floatSampled2 = VALUE_4(input, intSample, 2, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, 2, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, 2, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) +
VALUE_4(input, intSample, 2, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y));
floatOutput += VALUE_4(gradLoss, intSample, 0, intY, intX) * floatSampled0 * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX) +
VALUE_4(gradLoss, intSample, 1, intY, intX) * floatSampled1 * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX) +
VALUE_4(gradLoss, intSample, 2, intY, intX) * floatSampled2 * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX);
}
gradHorizontal[intIndex] = floatOutput;
} }
'''
kernel_DSepconv_updateGradMask = '''
extern "C" __global__ void kernel_DSepconv_updateGradMask(
const int n,
const float* gradLoss,
const float* input,
const float* vertical,
const float* horizontal,
const float* offset_x,
const float* offset_y,
float* gradMask
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
float floatOutput = 0.0;
const int intSample = ( intIndex / SIZE_3(gradMask) / SIZE_2(gradMask) / SIZE_1(gradMask) ) % SIZE_0(gradMask);
const int intDepth = ( intIndex / SIZE_3(gradMask) / SIZE_2(gradMask) ) % SIZE_1(gradMask);
const int intY = ( intIndex / SIZE_3(gradMask) ) % SIZE_2(gradMask);
const int intX = ( intIndex ) % SIZE_3(gradMask);
int intFilterY = intDepth / SIZE_1(vertical);
int intFilterX = intDepth % SIZE_1(vertical);
float delta_x = OFFSET_4(offset_y, intSample, intDepth, intY, intX);
float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX);
float position_x = delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1;
float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1;
if (position_x < 0)
position_x = 0;
if (position_x > SIZE_3(input) - 1)
position_x = SIZE_3(input) - 1;
if (position_y < 0)
position_y = 0;
if (position_y > SIZE_2(input) - 1)
position_y = SIZE_2(input) - 1;
int left = floor(delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1);
int right = left + 1;
if (left < 0)
left = 0;
if (left > SIZE_3(input) - 1)
left = SIZE_3(input) - 1;
if (right < 0)
right = 0;
if (right > SIZE_3(input) - 1)
right = SIZE_3(input) - 1;
int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1);
int bottom = top + 1;
if (top < 0)
top = 0;
if (top > SIZE_2(input) - 1)
top = SIZE_2(input) - 1;
if (bottom < 0)
bottom = 0;
if (bottom > SIZE_2(input) - 1)
bottom = SIZE_2(input) - 1;
for (int intChannel = 0; intChannel < 3; intChannel++){
floatOutput += VALUE_4(gradLoss, intSample, intChannel, intY, intX) * (
VALUE_4(input, intSample, intChannel, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, intChannel, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) +
VALUE_4(input, intSample, intChannel, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) +
VALUE_4(input, intSample, intChannel, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y))
) * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(horizontal, intSample, intFilterX, intY, intX);
}
gradMask[intIndex] = floatOutput;
} }
'''
kernel_DSepconv_updateGradOffsetX = '''
extern "C" __global__ void kernel_DSepconv_updateGradOffsetX(
const int n,
const float* gradLoss,
const float* input,
const float* vertical,
const float* horizontal,
const float* offset_x,
const float* offset_y,
const float* mask,
float* gradOffsetX
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
float floatOutput = 0.0;
const int intSample = ( intIndex / SIZE_3(gradOffsetX) / SIZE_2(gradOffsetX) / SIZE_1(gradOffsetX) ) % SIZE_0(gradOffsetX);
const int intDepth = ( intIndex / SIZE_3(gradOffsetX) / SIZE_2(gradOffsetX) ) % SIZE_1(gradOffsetX);
const int intY = ( intIndex / SIZE_3(gradOffsetX) ) % SIZE_2(gradOffsetX);
const int intX = ( intIndex ) % SIZE_3(gradOffsetX);
int intFilterY = intDepth / SIZE_1(vertical);
int intFilterX = intDepth % SIZE_1(vertical);
float delta_x = OFFSET_4(offset_y, intSample, intDepth, intY, intX);
float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX);
float position_x = delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1;
float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1;
if (position_x < 0)
position_x = 0;
if (position_x > SIZE_3(input) - 1)
position_x = SIZE_3(input) - 1;
if (position_y < 0)
position_y = 0;
if (position_y > SIZE_2(input) - 1)
position_y = SIZE_2(input) - 1;
int left = floor(delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1);
int right = left + 1;
if (left < 0)
left = 0;
if (left > SIZE_3(input) - 1)
left = SIZE_3(input) - 1;
if (right < 0)
right = 0;
if (right > SIZE_3(input) - 1)
right = SIZE_3(input) - 1;
int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1);
int bottom = top + 1;
if (top < 0)
top = 0;
if (top > SIZE_2(input) - 1)
top = SIZE_2(input) - 1;
if (bottom < 0)
bottom = 0;
if (bottom > SIZE_2(input) - 1)
bottom = SIZE_2(input) - 1;
for (int intChannel = 0; intChannel < 3; intChannel++){
floatOutput += VALUE_4(gradLoss, intSample, intChannel, intY, intX) * (
- VALUE_4(input, intSample, intChannel, top, left) * (1 + (left - position_x))
- VALUE_4(input, intSample, intChannel, top, right) * (1 - (right - position_x))
+ VALUE_4(input, intSample, intChannel, bottom, left) * (1 + (left - position_x))
+ VALUE_4(input, intSample, intChannel, bottom, right) * (1 - (right - position_x))
)
* VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(horizontal, intSample, intFilterX, intY, intX)
* VALUE_4(mask, intSample, intDepth, intY, intX);
}
gradOffsetX[intIndex] = floatOutput;
} }
'''
kernel_DSepconv_updateGradOffsetY = '''
extern "C" __global__ void kernel_DSepconv_updateGradOffsetY(
const int n,
const float* gradLoss,
const float* input,
const float* vertical,
const float* horizontal,
const float* offset_x,
const float* offset_y,
const float* mask,
float* gradOffsetY
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
float floatOutput = 0.0;
const int intSample = ( intIndex / SIZE_3(gradOffsetX) / SIZE_2(gradOffsetX) / SIZE_1(gradOffsetX) ) % SIZE_0(gradOffsetX);
const int intDepth = ( intIndex / SIZE_3(gradOffsetX) / SIZE_2(gradOffsetX) ) % SIZE_1(gradOffsetX);
const int intY = ( intIndex / SIZE_3(gradOffsetX) ) % SIZE_2(gradOffsetX);
const int intX = ( intIndex ) % SIZE_3(gradOffsetX);
int intFilterY = intDepth / SIZE_1(vertical);
int intFilterX = intDepth % SIZE_1(vertical);
float delta_x = OFFSET_4(offset_y, intSample, intDepth, intY, intX);
float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX);
float position_x = delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1;
float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1;
if (position_x < 0)
position_x = 0;
if (position_x > SIZE_3(input) - 1)
position_x = SIZE_3(input) - 1;
if (position_y < 0)
position_y = 0;
if (position_y > SIZE_2(input) - 1)
position_y = SIZE_2(input) - 1;
int left = floor(delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1);
int right = left + 1;
if (left < 0)
left = 0;
if (left > SIZE_3(input) - 1)
left = SIZE_3(input) - 1;
if (right < 0)
right = 0;
if (right > SIZE_3(input) - 1)
right = SIZE_3(input) - 1;
int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1);
int bottom = top + 1;
if (top < 0)
top = 0;
if (top > SIZE_2(input) - 1)
top = SIZE_2(input) - 1;
if (bottom < 0)
bottom = 0;
if (bottom > SIZE_2(input) - 1)
bottom = SIZE_2(input) - 1;
for (int intChannel = 0; intChannel < 3; intChannel++){
floatOutput += VALUE_4(gradLoss, intSample, intChannel, intY, intX) * (
- VALUE_4(input, intSample, intChannel, top, left) * (1 + (top - position_y))
+ VALUE_4(input, intSample, intChannel, top, right) * (1 + (top - position_y))
- VALUE_4(input, intSample, intChannel, bottom, left) * (1 - (bottom - position_y))
+ VALUE_4(input, intSample, intChannel, bottom, right) * (1 - (bottom - position_y))
)
* VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(horizontal, intSample, intFilterX, intY, intX)
* VALUE_4(mask, intSample, intDepth, intY, intX);
}
gradOffsetY[intIndex] = floatOutput;
} }
'''
def cupy_kernel(strFunction, objectVariables):
    strKernel = globals()[strFunction]
    # Replace SIZE_n(tensor) with the literal size of dimension n of that tensor.
    while True:
        objectMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
        if objectMatch is None:
            break
        # end
        intArg = int(objectMatch.group(2))
        strTensor = objectMatch.group(4)
        intSizes = objectVariables[strTensor].size()
        strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg]))
    # end
    # Replace VALUE_n(tensor, i0, ..., i{n-1}) with a strided element access.
    while True:
        objectMatch = re.search(r'(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
        if objectMatch is None:
            break
        # end
        intArgs = int(objectMatch.group(2))
        strArgs = objectMatch.group(4).split(',')
        strTensor = strArgs[0]
        intStrides = objectVariables[strTensor].stride()
        strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(
            intStrides[intArg]) + ')' for intArg in range(intArgs)]
        strKernel = strKernel.replace(objectMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
    # end
    # Replace OFFSET_n(tensor, ...) the same way as VALUE_n.
    while True:
        objectMatch = re.search(r'(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel)
        if objectMatch is None:
            break
        # end
        intArgs = int(objectMatch.group(2))
        strArgs = objectMatch.group(4).split(',')
        strTensor = strArgs[0]
        intStrides = objectVariables[strTensor].stride()
        strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(
            intStrides[intArg]) + ')' for intArg in range(intArgs)]
        strKernel = strKernel.replace(objectMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
    # end
    return strKernel
# end
@cupy.memoize(for_each_device=True)
def cupy_launch(strFunction, strKernel):
    # return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
    return cupy.RawKernel(strKernel, strFunction)
# end
class _FunctionDSepconv(torch.autograd.Function):
    @staticmethod
    def forward(self, input, vertical, horizontal, offset_x, offset_y, mask):
        self.save_for_backward(input, vertical, horizontal, offset_x, offset_y, mask)

        intSample = input.size(0)
        intInputDepth = input.size(1)
        intInputHeight = input.size(2)
        intInputWidth = input.size(3)
        intFilterSize = min(vertical.size(1), horizontal.size(1))
        intOutputHeight = min(vertical.size(2), horizontal.size(2))
        intOutputWidth = min(vertical.size(3), horizontal.size(3))

        assert (intInputHeight == intOutputHeight + intFilterSize - 1)
        assert (intInputWidth == intOutputWidth + intFilterSize - 1)

        assert (input.is_contiguous() == True)
        assert (vertical.is_contiguous() == True)
        assert (horizontal.is_contiguous() == True)
        assert (offset_x.is_contiguous() == True)
        assert (offset_y.is_contiguous() == True)
        assert (mask.is_contiguous() == True)

        output = input.new_zeros([intSample, intInputDepth, intOutputHeight, intOutputWidth])

        if input.is_cuda == True:
            n = output.nelement()
            cupy_launch('kernel_DSepconv_updateOutput', cupy_kernel('kernel_DSepconv_updateOutput', {
                'input': input,
                'vertical': vertical,
                'horizontal': horizontal,
                'offset_x': offset_x,
                'offset_y': offset_y,
                'mask': mask,
                'output': output
            }))(
                grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[n, input.data_ptr(), vertical.data_ptr(), horizontal.data_ptr(), offset_x.data_ptr(), offset_y.data_ptr(),
                      mask.data_ptr(), output.data_ptr()],
                stream=Stream
            )

        elif input.is_cuda == False:
            raise NotImplementedError()
        # end

        return output
    # end

    @staticmethod
    def backward(self, gradOutput):
        input, vertical, horizontal, offset_x, offset_y, mask = self.saved_tensors

        intSample = input.size(0)
        intInputDepth = input.size(1)
        intInputHeight = input.size(2)
        intInputWidth = input.size(3)
        intFilterSize = min(vertical.size(1), horizontal.size(1))
        intOutputHeight = min(vertical.size(2), horizontal.size(2))
        intOutputWidth = min(vertical.size(3), horizontal.size(3))

        assert (intInputHeight == intOutputHeight + intFilterSize - 1)
        assert (intInputWidth == intOutputWidth + intFilterSize - 1)

        assert (gradOutput.is_contiguous() == True)

        # a gradient buffer is allocated only if autograd requested it; no kernel fills gradInput
        gradInput = input.new_zeros([intSample, intInputDepth, intInputHeight, intInputWidth]) if \
            self.needs_input_grad[0] == True else None
        gradVertical = input.new_zeros([intSample, intFilterSize, intOutputHeight, intOutputWidth]) if \
            self.needs_input_grad[1] == True else None
        gradHorizontal = input.new_zeros([intSample, intFilterSize, intOutputHeight, intOutputWidth]) if \
            self.needs_input_grad[2] == True else None
        gradOffsetX = input.new_zeros([intSample, intFilterSize * intFilterSize, intOutputHeight, intOutputWidth]) if \
            self.needs_input_grad[3] == True else None
        gradOffsetY = input.new_zeros([intSample, intFilterSize * intFilterSize, intOutputHeight, intOutputWidth]) if \
            self.needs_input_grad[4] == True else None
        gradMask = input.new_zeros([intSample, intFilterSize * intFilterSize, intOutputHeight, intOutputWidth]) if \
            self.needs_input_grad[5] == True else None

        if input.is_cuda == True:
            # skip a kernel when its buffer is None, otherwise .nelement() would fail
            if gradVertical is not None:
                nv = gradVertical.nelement()
                cupy_launch('kernel_DSepconv_updateGradVertical', cupy_kernel('kernel_DSepconv_updateGradVertical', {
                    'gradLoss': gradOutput,
                    'input': input,
                    'horizontal': horizontal,
                    'offset_x': offset_x,
                    'offset_y': offset_y,
                    'mask': mask,
                    'gradVertical': gradVertical
                }))(
                    grid=tuple([int((nv + 512 - 1) / 512), 1, 1]),
                    block=tuple([512, 1, 1]),
                    args=[nv, gradOutput.data_ptr(), input.data_ptr(), horizontal.data_ptr(), offset_x.data_ptr(),
                          offset_y.data_ptr(), mask.data_ptr(), gradVertical.data_ptr()],
                    stream=Stream
                )
            # end

            if gradHorizontal is not None:
                nh = gradHorizontal.nelement()
                cupy_launch('kernel_DSepconv_updateGradHorizontal', cupy_kernel('kernel_DSepconv_updateGradHorizontal', {
                    'gradLoss': gradOutput,
                    'input': input,
                    'vertical': vertical,
                    'offset_x': offset_x,
                    'offset_y': offset_y,
                    'mask': mask,
                    'gradHorizontal': gradHorizontal
                }))(
                    grid=tuple([int((nh + 512 - 1) / 512), 1, 1]),
                    block=tuple([512, 1, 1]),
                    args=[nh, gradOutput.data_ptr(), input.data_ptr(), vertical.data_ptr(), offset_x.data_ptr(),
                          offset_y.data_ptr(), mask.data_ptr(), gradHorizontal.data_ptr()],
                    stream=Stream
                )
            # end

            if gradOffsetX is not None:
                nx = gradOffsetX.nelement()
                cupy_launch('kernel_DSepconv_updateGradOffsetX', cupy_kernel('kernel_DSepconv_updateGradOffsetX', {
                    'gradLoss': gradOutput,
                    'input': input,
                    'vertical': vertical,
                    'horizontal': horizontal,
                    'offset_x': offset_x,
                    'offset_y': offset_y,
                    'mask': mask,
                    'gradOffsetX': gradOffsetX
                }))(
                    grid=tuple([int((nx + 512 - 1) / 512), 1, 1]),
                    block=tuple([512, 1, 1]),
                    args=[nx, gradOutput.data_ptr(), input.data_ptr(), vertical.data_ptr(), horizontal.data_ptr(), offset_x.data_ptr(),
                          offset_y.data_ptr(), mask.data_ptr(), gradOffsetX.data_ptr()],
                    stream=Stream
                )
            # end

            if gradOffsetY is not None:
                ny = gradOffsetY.nelement()
                cupy_launch('kernel_DSepconv_updateGradOffsetY', cupy_kernel('kernel_DSepconv_updateGradOffsetY', {
                    'gradLoss': gradOutput,
                    'input': input,
                    'vertical': vertical,
                    'horizontal': horizontal,
                    'offset_x': offset_x,
                    'offset_y': offset_y,
                    'mask': mask,
                    'gradOffsetX': gradOffsetY
                }))(
                    grid=tuple([int((ny + 512 - 1) / 512), 1, 1]),
                    block=tuple([512, 1, 1]),
                    args=[ny, gradOutput.data_ptr(), input.data_ptr(), vertical.data_ptr(), horizontal.data_ptr(),
                          offset_x.data_ptr(),
                          offset_y.data_ptr(), mask.data_ptr(), gradOffsetY.data_ptr()],
                    stream=Stream
                )
            # end

            if gradMask is not None:
                nm = gradMask.nelement()
                cupy_launch('kernel_DSepconv_updateGradMask', cupy_kernel('kernel_DSepconv_updateGradMask', {
                    'gradLoss': gradOutput,
                    'input': input,
                    'vertical': vertical,
                    'horizontal': horizontal,
                    'offset_x': offset_x,
                    'offset_y': offset_y,
                    'gradMask': gradMask
                }))(
                    grid=tuple([int((nm + 512 - 1) / 512), 1, 1]),
                    block=tuple([512, 1, 1]),
                    args=[nm, gradOutput.data_ptr(), input.data_ptr(), vertical.data_ptr(), horizontal.data_ptr(),
                          offset_x.data_ptr(),
                          offset_y.data_ptr(), gradMask.data_ptr()],
                    stream=Stream
                )
            # end

        elif input.is_cuda == False:
            raise NotImplementedError()
        # end

        return gradInput, gradVertical, gradHorizontal, gradOffsetX, gradOffsetY, gradMask
    # end
# end
def FunctionDSepconv(tensorInput, tensorVertical, tensorHorizontal, tensorOffsetX, tensorOffsetY, tensorMask):
    return _FunctionDSepconv.apply(tensorInput, tensorVertical, tensorHorizontal, tensorOffsetX, tensorOffsetY, tensorMask)
# end

class ModuleDSepconv(torch.nn.Module):
    def __init__(self):
        super(ModuleDSepconv, self).__init__()
    # end

    def forward(self, tensorInput, tensorVertical, tensorHorizontal, tensorOffsetX, tensorOffsetY, tensorMask):
        return _FunctionDSepconv.apply(tensorInput, tensorVertical, tensorHorizontal, tensorOffsetX, tensorOffsetY, tensorMask)
    # end
# end
# float floatValue = VALUE_4(input, intSample, intDepth, top, left) * (1 - (delta_x - floor(delta_x))) * (1 - (delta_y - floor(delta_y))) +
# VALUE_4(input, intSample, intDepth, top, right) * (delta_x - floor(delta_x)) * (1 - (delta_y - floor(delta_y))) +
# VALUE_4(input, intSample, intDepth, bottom, left) * (1 - (delta_x - floor(delta_x))) * (delta_y - floor(delta_y)) +
# VALUE_4(input, intSample, intDepth, bottom, right) * (delta_x - floor(delta_x)) * (delta_y - floor(delta_y));
================================================
FILE: environment.yaml
================================================
name: ldmvfi
channels:
  - pytorch
  - defaults
  - conda-forge
dependencies:
  - python=3.9.13
  - pytorch=1.11.0
  - torchvision=0.12.0
  - cudatoolkit=11.3
  - pip:
    - opencv-python==4.6.0.66
    - pudb==2022.1.3
    - imageio==2.22.3
    - imageio-ffmpeg==0.4.7
    - pytorch-lightning==1.7.7
    - omegaconf==2.2.3
    - test-tube==0.7.5
    - streamlit==1.14.0
    - einops==0.5.0
    - torch-fidelity==0.3.0
    - transformers==4.23.1
    - timm==0.6.12
    - cupy
    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
    - -e git+https://github.com/openai/CLIP.git@main#egg=clip
    - -e .
# conda create -n ldmvfi python=3.9
# conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
# pip install opencv-python==4.6.0.66 pudb==2022.1.3 imageio==2.22.3 imageio-ffmpeg==0.4.7 pytorch-lightning==1.7.7 omegaconf==2.2.3 test-tube==0.7.5 streamlit==1.14.0 einops==0.5.0 torch-fidelity==0.3.0 transformers==4.23.1
# pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
# pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
# pip install -e .
# pip install timm
================================================
FILE: evaluate.py
================================================
import argparse
import os
import torch
from functools import partial
from omegaconf import OmegaConf
from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.data import testsets

parser = argparse.ArgumentParser(description='Frame Interpolation Evaluation')
parser.add_argument('--config', type=str, default=None)
parser.add_argument('--ckpt', type=str, default=None)
parser.add_argument('--dataset', type=str, default='Middlebury_others')
parser.add_argument('--metrics', nargs='+', type=str, default=['PSNR', 'SSIM', 'LPIPS'])
parser.add_argument('--data_dir', type=str, default='D:\\')
parser.add_argument('--out_dir', type=str, default='eval_results')
parser.add_argument('--resume', dest='resume', default=False, action='store_true')

# sampler args
parser.add_argument('--use_ddim', dest='use_ddim', default=False, action='store_true')
parser.add_argument('--ddim_eta', type=float, default=1.0)
parser.add_argument('--ddim_steps', type=int, default=200)


def main():
    args = parser.parse_args()

    # initialise model
    config = OmegaConf.load(args.config)
    model = instantiate_from_config(config.model)
    # load the checkpoint on the CPU first; the model is moved to the target device below
    model.load_state_dict(torch.load(args.ckpt, map_location='cpu')['state_dict'])
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    model = model.eval()
    print('Model loaded successfully')

    # set up sampler
    if args.use_ddim:
        ddim = DDIMSampler(model)
        sample_func = partial(ddim.sample, S=args.ddim_steps, eta=args.ddim_eta, verbose=False)
    else:
        sample_func = partial(model.sample_ddpm, return_intermediates=False, verbose=False)

    # setup output dirs
    os.makedirs(args.out_dir, exist_ok=True)

    # initialise test set
    print('Testing on dataset: ', args.dataset)
    test_dir = os.path.join(args.out_dir, args.dataset)
    # some datasets share one data folder named after the prefix before the first underscore
    if args.dataset.split('_')[0] in ['VFITex', 'Ucf101', 'Davis90']:
        db_folder = args.dataset.split('_')[0].lower()
    else:
        db_folder = args.dataset.lower()
    test_db = getattr(testsets, args.dataset)(os.path.join(args.data_dir, db_folder))
    os.makedirs(test_dir, exist_ok=True)
    test_db.eval(model, sample_func, metrics=args.metrics, output_dir=test_dir, resume=args.resume)


if __name__ == '__main__':
    main()
================================================
FILE: evaluate_vqm.py
================================================
import argparse
import os
from ldm.data import testsets_vqm

parser = argparse.ArgumentParser(description='Frame Interpolation Evaluation')
parser.add_argument('--exp', type=str, default=None)
parser.add_argument('--dataset', type=str, default='Middlebury_others')
parser.add_argument('--metrics', nargs='+', type=str, default=['FloLPIPS'])
parser.add_argument('--data_dir', type=str, default='D:\\')
parser.add_argument('--out_dir', type=str, default='eval_results')
parser.add_argument('--resume', dest='resume', default=False, action='store_true')


def main():
    args = parser.parse_args()

    # name of the experiment being evaluated (no network is loaded here)
    model = args.exp
    print('Evaluating model:', model)

    # setup output dirs
    assert os.path.exists(args.out_dir), 'Frames not previously interpolated!'

    # initialise test set
    print('Testing on dataset: ', args.dataset)
    test_dir = os.path.join(args.out_dir, args.dataset)
    assert os.path.exists(test_dir), f'{args.dataset} not pre-computed!'
    if args.dataset.split('_')[0] in ['VFITex', 'Ucf101', 'Davis90']:
        db_folder = args.dataset.split('_')[0].lower()
    else:
        db_folder = args.dataset.lower()
    test_db = getattr(testsets_vqm, args.dataset)(os.path.join(args.data_dir, db_folder))
    test_db.eval(metrics=args.metrics, output_dir=test_dir, resume=args.resume)


if __name__ == '__main__':
    main()
================================================
FILE: interpolate_yuv.py
================================================
import argparse
import torch
import torchvision.transforms.functional as TF
import os
from PIL import Image
from tqdm import tqdm
import skvideo.io
from functools import partial
from utility import read_frame_yuv2rgb, tensor2rgb
from omegaconf import OmegaConf
from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

parser = argparse.ArgumentParser(description='Frame Interpolation Evaluation')
parser.add_argument('--net', type=str, default='LDMVFI')
parser.add_argument('--config', type=str, default='configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml')
parser.add_argument('--ckpt', type=str, default='ckpt.pth')
parser.add_argument('--input_yuv', type=str, default='D:\\')
parser.add_argument('--size', type=str, default='1920x1080')
parser.add_argument('--out_fps', type=int, default=60)
parser.add_argument('--out_dir', type=str, default='.')

# sampler args
parser.add_argument('--use_ddim', dest='use_ddim', default=False, action='store_true')
parser.add_argument('--ddim_eta', type=float, default=1.0)
parser.add_argument('--ddim_steps', type=int, default=200)


def main():
    args = parser.parse_args()

    # initialise model
    config = OmegaConf.load(args.config)
    model = instantiate_from_config(config.model)
    # load the checkpoint on the CPU first; the model is moved to the target device below
    model.load_state_dict(torch.load(args.ckpt, map_location='cpu')['state_dict'])
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    model = model.eval()
    print('Model loaded successfully')

    # set up sampler
    if args.use_ddim:
        ddim = DDIMSampler(model)
        sample_func = partial(ddim.sample, S=args.ddim_steps, eta=args.ddim_eta, verbose=False)
    else:
        sample_func = partial(model.sample_ddpm, return_intermediates=False, verbose=False)

    # Setup output file
    os.makedirs(args.out_dir, exist_ok=True)
    _, fname = os.path.split(args.input_yuv)
    # drop only the extension; str.strip('.yuv') would also remove any leading or
    # trailing '.', 'y', 'u' and 'v' characters from the sequence name
    seq_name = os.path.splitext(fname)[0]
    bit_depth = 16 if '16bit' in fname else 10 if '10bit' in fname else 8
    pix_fmt = '444' if '444' in fname else '420'
    try:
        width, height = (int(v) for v in args.size.split('x'))
    except ValueError:
        print('Invalid size, should be \'<width>x<height>\'')
        return

    outname = '{}_{}x{}_{}fps_{}.mp4'.format(seq_name, width, height, args.out_fps, args.net)
    writer = skvideo.io.FFmpegWriter(os.path.join(args.out_dir, outname),
        inputdict={
            '-r': str(args.out_fps)
        },
        outputdict={
            '-pix_fmt': 'yuv420p',
            '-s': '{}x{}'.format(width, height),
            '-r': str(args.out_fps),
            '-vcodec': 'libx264',  # use the H.264 codec
            '-crf': '0',           # constant rate factor 0, i.e. lossless
            '-preset': 'veryslow'  # slower presets compress better in principle;
                                   # for other options see https://trac.ffmpeg.org/wiki/Encode/H.264
        }
    )

    # Start interpolation
    print('Using model {} to upsample file {}'.format(args.net, fname))
    # raw YUV is binary data, so open the file in binary mode
    stream = open(args.input_yuv, 'rb')
    file_size = os.path.getsize(args.input_yuv)

    # YUV reading setup: 4:2:0 stores 1.5 samples per pixel, 4:4:4 stores 3
    bytes_per_frame = width*height*1.5
    if pix_fmt == '444':
        bytes_per_frame *= 2
    if bit_depth != 8:
        # 10-bit and 16-bit samples take two bytes each
        bytes_per_frame *= 2

    num_frames = int(file_size // bytes_per_frame)
    rawFrame0 = Image.fromarray(read_frame_yuv2rgb(stream, width, height, 0, bit_depth, pix_fmt))
    frame0 = TF.normalize(TF.to_tensor(rawFrame0), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))[None,...].to(device)

    for t in tqdm(range(1, num_frames)):
        rawFrame1 = Image.fromarray(read_frame_yuv2rgb(stream, width, height, t, bit_depth, pix_fmt))
        frame1 = TF.normalize(TF.to_tensor(rawFrame1), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))[None,...].to(device)

        with torch.no_grad():
            with model.ema_scope():
                # form condition tensor and define shape of latent rep
                xc = {'prev_frame': frame0, 'next_frame': frame1}
                c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc)
                shape = (model.channels, c.shape[2], c.shape[3])
                # run sampling and get denoised latent
                out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape)
                if isinstance(out, tuple): # using ddim
                    out = out[0]
                # reconstruct interpolated frame from latent
                out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list)
                out = torch.clamp(out, min=-1., max=1.) # interpolated frame in [-1,1]

        # write to output video
        writer.writeFrame(tensor2rgb(frame0)[0])
        writer.writeFrame(tensor2rgb(out)[0])

        # update frame0
        frame0 = frame1

    # write the last frame; frame0 holds it after the loop, and it is also
    # defined when the input contains a single frame
    writer.writeFrame(tensor2rgb(frame0)[0])

    stream.close()
    writer.close() # close the writer


if __name__ == "__main__":
    main()
================================================
FILE: ldm/data/__init__.py
================================================
================================================
FILE: ldm/data/bvi_vimeo.py
================================================
import numpy as np
import random
from os import listdir
from os.path import join, isdir, split, getsize
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
from PIL import Image
import ldm.data.vfitransforms as vt
from functools import partial
class Vimeo90k_triplet(Dataset):
    def __init__(self, db_dir, train=True, crop_sz=(256,256), augment_s=True, augment_t=True):
        seq_dir = join(db_dir, 'sequences')
        self.crop_sz = crop_sz
        self.augment_s = augment_s
        self.augment_t = augment_t

        if train:
            seq_list_txt = join(db_dir, 'sep_trainlist.txt')
        else:
            seq_list_txt = join(db_dir, 'sep_testlist.txt')

        with open(seq_list_txt) as f:
            contents = f.readlines()
            seq_path = [line.strip() for line in contents if line != '\n']

        self.seq_path_list = [join(seq_dir, *line.split('/')) for line in seq_path]

    def __getitem__(self, index):
        rawFrame3 = Image.open(join(self.seq_path_list[index], "im3.png"))
        rawFrame4 = Image.open(join(self.seq_path_list[index], "im4.png"))
        rawFrame5 = Image.open(join(self.seq_path_list[index], "im5.png"))

        if self.crop_sz is not None:
            rawFrame3, rawFrame4, rawFrame5 = vt.rand_crop(rawFrame3, rawFrame4, rawFrame5, sz=self.crop_sz)

        if self.augment_s:
            rawFrame3, rawFrame4, rawFrame5 = vt.rand_flip(rawFrame3, rawFrame4, rawFrame5, p=0.5)

        if self.augment_t:
            rawFrame3, rawFrame4, rawFrame5 = vt.rand_reverse(rawFrame3, rawFrame4, rawFrame5, p=0.5)

        to_array = partial(np.array, dtype=np.float32)
        frame3, frame4, frame5 = map(to_array, (rawFrame3, rawFrame4, rawFrame5)) #(256,256,3), 0-255

        # rescale from [0,255] to [-1,1]
        frame3 = frame3/127.5 - 1.0
        frame4 = frame4/127.5 - 1.0
        frame5 = frame5/127.5 - 1.0

        return {'image': frame4, 'prev_frame': frame3, 'next_frame': frame5}

    def __len__(self):
        return len(self.seq_path_list)
class Vimeo90k_quintuplet(Dataset):
    def __init__(self, db_dir, train=True, crop_sz=(256,256), augment_s=True, augment_t=True):
        seq_dir = join(db_dir, 'sequences')
        self.crop_sz = crop_sz
        self.augment_s = augment_s
        self.augment_t = augment_t

        if train:
            seq_list_txt = join(db_dir, 'sep_trainlist.txt')
        else:
            seq_list_txt = join(db_dir, 'sep_testlist.txt')

        with open(seq_list_txt) as f:
            contents = f.readlines()
            seq_path = [line.strip() for line in contents if line != '\n']

        self.seq_path_list = [join(seq_dir, *line.split('/')) for line in seq_path]

    def __getitem__(self, index):
        rawFrame1 = Image.open(join(self.seq_path_list[index], "im1.png"))
        rawFrame3 = Image.open(join(self.seq_path_list[index], "im3.png"))
        rawFrame4 = Image.open(join(self.seq_path_list[index], "im4.png"))
        rawFrame5 = Image.open(join(self.seq_path_list[index], "im5.png"))
        rawFrame7 = Image.open(join(self.seq_path_list[index], "im7.png"))

        if self.crop_sz is not None:
            rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_crop(rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7, sz=self.crop_sz)

        if self.augment_s:
            rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_flip(rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7, p=0.5)

        if self.augment_t:
            rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_reverse(rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7, p=0.5)

        frame1, frame3, frame4, frame5, frame7 = map(TF.to_tensor, (rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7))

        return frame1, frame3, frame4, frame5, frame7

    def __len__(self):
        return len(self.seq_path_list)
class BVIDVC_triplet(Dataset):
    def __init__(self, db_dir, res=None, crop_sz=(256,256), augment_s=True, augment_t=True):
        db_dir = join(db_dir, 'quintuplets')
        self.crop_sz = crop_sz
        self.augment_s = augment_s
        self.augment_t = augment_t
        self.seq_path_list = [join(db_dir, f) for f in listdir(db_dir)]

    def __getitem__(self, index):
        # quintuplet.png stores five 256x256 frames side by side, with the middle frame (4) last
        cat = Image.open(join(self.seq_path_list[index], 'quintuplet.png'))
        rawFrame3 = cat.crop((256, 0, 256*2, 256))
        rawFrame5 = cat.crop((256*2, 0, 256*3, 256))
        rawFrame4 = cat.crop((256*4, 0, 256*5, 256))

        if self.crop_sz is not None:
            rawFrame3, rawFrame4, rawFrame5 = vt.rand_crop(rawFrame3, rawFrame4, rawFrame5, sz=self.crop_sz)

        if self.augment_s:
            rawFrame3, rawFrame4, rawFrame5 = vt.rand_flip(rawFrame3, rawFrame4, rawFrame5, p=0.5)

        if self.augment_t:
            rawFrame3, rawFrame4, rawFrame5 = vt.rand_reverse(rawFrame3, rawFrame4, rawFrame5, p=0.5)

        to_array = partial(np.array, dtype=np.float32)
        frame3, frame4, frame5 = map(to_array, (rawFrame3, rawFrame4, rawFrame5)) #(256,256,3), 0-255

        # rescale from [0,255] to [-1,1]
        frame3 = frame3/127.5 - 1.0
        frame4 = frame4/127.5 - 1.0
        frame5 = frame5/127.5 - 1.0

        return {'image': frame4, 'prev_frame': frame3, 'next_frame': frame5}

    def __len__(self):
        return len(self.seq_path_list)
class BVIDVC_quintuplet(Dataset):
    def __init__(self, db_dir, res=None, crop_sz=(256,256), augment_s=True, augment_t=True):
        db_dir = join(db_dir, 'quintuplets')
        self.crop_sz = crop_sz
        self.augment_s = augment_s
        self.augment_t = augment_t
        self.seq_path_list = [join(db_dir, f) for f in listdir(db_dir)]

    def __getitem__(self, index):
        # quintuplet.png stores five 256x256 frames side by side, with the middle frame (4) last
        cat = Image.open(join(self.seq_path_list[index], 'quintuplet.png'))
        rawFrame1 = cat.crop((0, 0, 256, 256))
        rawFrame3 = cat.crop((256, 0, 256*2, 256))
        rawFrame5 = cat.crop((256*2, 0, 256*3, 256))
        rawFrame7 = cat.crop((256*3, 0, 256*4, 256))
        rawFrame4 = cat.crop((256*4, 0, 256*5, 256))

        if self.augment_s:
            rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_flip(rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7, p=0.5)

        if self.augment_t:
            rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_reverse(rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7, p=0.5)

        frame1, frame3, frame4, frame5, frame7 = map(TF.to_tensor, (rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7))

        return frame1, frame3, frame4, frame5, frame7

    def __len__(self):
        return len(self.seq_path_list)
class Sampler(Dataset):
    def __init__(self, datasets, p_datasets=None, iter=False, samples_per_epoch=1000):
        self.datasets = datasets
        self.len_datasets = np.array([len(dataset) for dataset in self.datasets])
        self.p_datasets = p_datasets
        self.iter = iter

        if p_datasets is None:
            # by default, sample each dataset in proportion to its size
            self.p_datasets = self.len_datasets / np.sum(self.len_datasets)

        self.samples_per_epoch = samples_per_epoch

        # cumulative dataset lengths, used to map a global index to (dataset, local index)
        self.accum = [0,]
        for length in self.len_datasets:
            self.accum.append(self.accum[-1] + length)

    def __getitem__(self, index):
        if self.iter:
            # iterate through all datasets
            for i in range(len(self.accum)):
                if index < self.accum[i]:
                    return self.datasets[i-1].__getitem__(index-self.accum[i-1])
            # index lies beyond the last dataset
            raise IndexError(f'index {index} out of range')
        else:
            # first sample a dataset
            dataset = random.choices(self.datasets, self.p_datasets)[0]
            # sample a sequence from the dataset
            return dataset.__getitem__(random.randint(0,len(dataset)-1))

    def __len__(self):
        if self.iter:
            return int(np.sum(self.len_datasets))
        else:
            return self.samples_per_epoch
class BVI_Vimeo_triplet(Dataset):
    def __init__(self, db_dir, crop_sz=(256,256), p_datasets=None, iter=False, samples_per_epoch=1000):
        vimeo90k_train = Vimeo90k_triplet(join(db_dir, 'vimeo_septuplet'), train=True, crop_sz=crop_sz)
        bvidvc_train = BVIDVC_triplet(join(db_dir, 'bvidvc'), crop_sz=crop_sz)

        self.datasets = [vimeo90k_train, bvidvc_train]
        self.len_datasets = np.array([len(dataset) for dataset in self.datasets])
        self.p_datasets = p_datasets
        self.iter = iter

        if p_datasets is None:
            # by default, sample each dataset in proportion to its size
            self.p_datasets = self.len_datasets / np.sum(self.len_datasets)

        self.samples_per_epoch = samples_per_epoch

        # cumulative dataset lengths, used to map a global index to (dataset, local index)
        self.accum = [0,]
        for length in self.len_datasets:
            self.accum.append(self.accum[-1] + length)

    def __getitem__(self, index):
        if self.iter:
            # iterate through all datasets
            for i in range(len(self.accum)):
                if index < self.accum[i]:
                    return self.datasets[i-1].__getitem__(index-self.accum[i-1])
            # index lies beyond the last dataset
            raise IndexError(f'index {index} out of range')
        else:
            # first sample a dataset
            dataset = random.choices(self.datasets, self.p_datasets)[0]
            # sample a sequence from the dataset
            return dataset.__getitem__(random.randint(0,len(dataset)-1))

    def __len__(self):
        if self.iter:
            return int(np.sum(self.len_datasets))
        else:
            return self.samples_per_epoch
================================================
FILE: ldm/data/testsets.py
================================================
import glob
from typing import List
from PIL import Image
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.utils import save_image as imwrite
import os
from os.path import join, exists
import utility
import numpy as np
import ast
import time
from ldm.models.autoencoder import *
class TripletTestSet:
    def __init__(self):
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1,1]

    def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_dir=None, output_name='output.png', resume=False):
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results.txt')), 'no res file found to resume from!'
            with open(join(output_dir, 'results.txt'), 'r') as f:
                prev_lines = f.readlines()
            for line in prev_lines:
                if len(line) < 2:
                    continue
                cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
                for k in metrics:
                    results_dict[k].append(float(cur_res[k]))
                start_idx += 1

        logfile = open(join(output_dir, 'results.txt'), 'a')

        for idx in range(len(self.im_list)):
            if resume and idx < start_idx:
                assert os.path.exists(join(output_dir, self.im_list[idx], output_name)), f'skipping idx {idx} but output not found!'
                continue

            print(f'Evaluating {self.im_list[idx]}')
            t0 = time.time()

            if not exists(join(output_dir, self.im_list[idx])):
                os.makedirs(join(output_dir, self.im_list[idx]))

            with torch.no_grad():
                with model.ema_scope():
                    # form condition tensor and define shape of latent rep
                    xc = {'prev_frame': self.input0_list[idx], 'next_frame': self.input1_list[idx]}
                    c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc)
                    shape = (model.channels, c.shape[2], c.shape[3])
                    # run sampling and get denoised latent rep
                    out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape, x_T=None)
                    if isinstance(out, tuple): # using ddim
                        out = out[0]
                    # reconstruct interpolated frame from latent
                    out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list)
                    out = torch.clamp(out, min=-1., max=1.) # interpolated frame in [-1,1]

            gt = self.gt_list[idx]

            for metric in metrics:
                score = getattr(utility, 'calc_{}'.format(metric.lower()))(gt, out, [self.input0_list[idx], self.input1_list[idx]])[0].item()
                results_dict[metric].append(score)

            imwrite(out, join(output_dir, self.im_list[idx], output_name), value_range=(-1, 1), normalize=True)

            msg = '{:<15s} -- {}'.format(self.im_list[idx], {k: round(results_dict[k][-1],3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n'
            print(msg, end='')
            logfile.write(msg)

        msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]),3) for k in metrics}) + '\n\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()
class Middlebury_others(TripletTestSet):
    def __init__(self, db_dir):
        super(Middlebury_others, self).__init__()
        self.im_list = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking']

        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input0_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame10.png'))).cuda().unsqueeze(0)) # [1,3,H,W] in [-1,1]
            self.input1_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame11.png'))).cuda().unsqueeze(0)) # [1,3,H,W] in [-1,1]
            self.gt_list.append(self.transform(Image.open(join(db_dir, 'gt', item , 'frame10i11.png'))).cuda().unsqueeze(0))
class Davis(TripletTestSet):
    def __init__(self, db_dir):
        super(Davis, self).__init__()
        self.im_list = ['bike-trial', 'boxing', 'burnout', 'choreography', 'demolition', 'dive-in', 'dolphins', 'e-bike', 'grass-chopper', 'hurdles', 'inflatable', 'juggle', 'kart-turn', 'kids-turning', 'lions', 'mbike-santa', 'monkeys', 'ocean-birds', 'pole-vault', 'running', 'selfie', 'skydive', 'speed-skating', 'swing-boy', 'tackle', 'turtle', 'varanus-tree', 'vietnam', 'wings-turn']

        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input0_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame10.png'))).cuda().unsqueeze(0))
            self.input1_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame11.png'))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, 'gt', item , 'frame10i11.png'))).cuda().unsqueeze(0))
class Ucf(TripletTestSet):
    def __init__(self, db_dir):
        super(Ucf, self).__init__()
        self.im_list = os.listdir(db_dir)

        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input0_list.append(self.transform(Image.open(join(db_dir, item , 'frame_00.png'))).cuda().unsqueeze(0))
            self.input1_list.append(self.transform(Image.open(join(db_dir, item , 'frame_02.png'))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, item , 'frame_01_gt.png'))).cuda().unsqueeze(0))
class Snufilm(TripletTestSet):
    def __init__(self, db_dir, mode):
        super(Snufilm, self).__init__()
        self.mode = mode
        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        with open(join(db_dir, 'test-{}.txt'.format(mode)), 'r') as f:
            triplet_list = f.read().splitlines()
        self.im_list = []
        for i, triplet in enumerate(triplet_list, 1):
            self.im_list.append('{}-{}'.format(mode, str(i).zfill(3)))
            lst = triplet.split(' ')
            self.input0_list.append(self.transform(Image.open(join(db_dir, lst[0]))).cuda().unsqueeze(0))
            self.input1_list.append(self.transform(Image.open(join(db_dir, lst[2]))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, lst[1]))).cuda().unsqueeze(0))


# Each subclass receives a db_dir ending in '_<mode>' (derived from the dataset name)
# and strips that suffix to recover the shared SNU-FILM folder.
class Snufilm_easy(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-5]
        super(Snufilm_easy, self).__init__(db_dir, 'easy')


class Snufilm_medium(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-7]
        super(Snufilm_medium, self).__init__(db_dir, 'medium')


class Snufilm_hard(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-5]
        super(Snufilm_hard, self).__init__(db_dir, 'hard')


class Snufilm_extreme(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-8]
        super(Snufilm_extreme, self).__init__(db_dir, 'extreme')
class VFITex_triplet:
    def __init__(self, db_dir):
        self.seq_list = os.listdir(db_dir)
        self.db_dir = db_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1,1]
def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_dir=None, output_name=None, resume=False):
model.eval()
results_dict = {k : [] for k in metrics}
start_idx = 0
if resume:
# fill in results_dict with prev results and find where to start from
assert os.path.exists(join(output_dir, 'results.txt')), 'no res file found to resume from!'
with open(join(output_dir, 'results.txt'), 'r') as f:
prev_lines = f.readlines()
for line in prev_lines:
if len(line) < 2:
continue
cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
for k in metrics:
results_dict[k].append(float(cur_res[k]))
start_idx += 1
logfile = open(join(output_dir, 'results.txt'), 'a')
for idx, seq in enumerate(self.seq_list):
if resume and idx < start_idx:
assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'skipping idx {idx} but output not found!'
continue
print(f'Evaluating {seq}')
t0 = time.time()
seqpath = join(self.db_dir, seq)
if not exists(join(output_dir, seq)):
os.makedirs(join(output_dir, seq))
# interpolate between every 2 frames
gt_list, out_list, inputs_list = [], [], []
tmp_dict = {k : [] for k in metrics}
num_frames = len([f for f in os.listdir(seqpath) if f.endswith('.png')])
for t in range(1, num_frames-5, 2):
im0 = Image.open(join(seqpath, str(t+2).zfill(3)+'.png'))
im1 = Image.open(join(seqpath, str(t+3).zfill(3)+'.png'))
im2 = Image.open(join(seqpath, str(t+4).zfill(3)+'.png'))
# center crop if 4K
if '4K' in seq:
w, h = im0.size
im0 = TF.center_crop(im0, (h//2, w//2))
im1 = TF.center_crop(im1, (h//2, w//2))
im2 = TF.center_crop(im2, (h//2, w//2))
im0 = self.transform(im0).cuda().unsqueeze(0)
im1 = self.transform(im1).cuda().unsqueeze(0)
im2 = self.transform(im2).cuda().unsqueeze(0)
with torch.no_grad():
with model.ema_scope():
# form condition tensor and define shape of latent rep
xc = {'prev_frame': im0, 'next_frame': im2}
c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc)
shape = (model.channels, c.shape[2], c.shape[3])
# run sampling and get denoised latent rep
out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape, x_T=None)
if isinstance(out, tuple): # using ddim
out = out[0]
# reconstruct interpolated frame from latent
out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list)
out = torch.clamp(out, min=-1., max=1.) # interpolated frame in [-1,1]
for metric in metrics:
score = getattr(utility, 'calc_{}'.format(metric.lower()))(im1, out, [im0, im2])[0].item()
tmp_dict[metric].append(score)
imwrite(out, join(output_dir, seq, 'frame{}.png'.format(t+3)), value_range=(-1, 1), normalize=True)
# compute sequence-level scores
for metric in metrics:
results_dict[metric].append(np.mean(tmp_dict[metric]))
# log
msg = '{:<15s} -- {}'.format(seq, {k: round(results_dict[k][-1], 3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n'
print(msg, end='')
logfile.write(msg)
msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]), 3) for k in metrics}) + '\n'
print(msg, end='')
logfile.write(msg)
logfile.close()
class Davis90_triplet:
def __init__(self, db_dir):
self.seq_list = sorted(os.listdir(db_dir))
self.db_dir = db_dir
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1,1]
def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_dir=None, output_name=None, resume=False):
model.eval()
results_dict = {k : [] for k in metrics}
start_idx = 0
if resume:
# fill in results_dict with prev results and find where to start from
assert os.path.exists(join(output_dir, 'results.txt')), 'no res file found to resume from!'
with open(join(output_dir, 'results.txt'), 'r') as f:
prev_lines = f.readlines()
for line in prev_lines:
if len(line) < 2:
continue
cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
for k in metrics:
results_dict[k].append(float(cur_res[k]))
start_idx += 1
logfile = open(join(output_dir, 'results.txt'), 'a')
for idx, seq in enumerate(self.seq_list):
if resume and idx < start_idx:
assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'skipping idx {idx} but output not found!'
continue
print(f'Evaluating {seq}')
t0 = time.time()
seqpath = join(self.db_dir, seq)
if not exists(join(output_dir, seq)):
os.makedirs(join(output_dir, seq))
# interpolate between every 2 frames
gt_list, out_list, inputs_list = [], [], []
tmp_dict = {k : [] for k in metrics}
num_frames = len(os.listdir(seqpath))
for t in range(0, num_frames-6, 2):
im3 = Image.open(join(seqpath, str(t+2).zfill(5)+'.jpg'))
im4 = Image.open(join(seqpath, str(t+3).zfill(5)+'.jpg'))
im5 = Image.open(join(seqpath, str(t+4).zfill(5)+'.jpg'))
im3 = self.transform(im3).cuda().unsqueeze(0)
im4 = self.transform(im4).cuda().unsqueeze(0)
im5 = self.transform(im5).cuda().unsqueeze(0)
with torch.no_grad():
with model.ema_scope():
# form condition tensor and define shape of latent rep
xc = {'prev_frame': im3, 'next_frame': im5}
c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc)
shape = (model.channels, c.shape[2], c.shape[3])
# run sampling and get denoised latent rep
out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape, x_T=None)
if isinstance(out, tuple): # using ddim
out = out[0]
# reconstruct interpolated frame from latent
out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list)
out = torch.clamp(out, min=-1., max=1.) # interpolated frame in [-1,1]
for metric in metrics:
score = getattr(utility, 'calc_{}'.format(metric.lower()))(im4, out, [im3, im5])[0].item()
tmp_dict[metric].append(score)
imwrite(out, join(output_dir, seq, 'frame{}.png'.format(t+3)), value_range=(-1, 1), normalize=True)
# compute sequence-level scores
for metric in metrics:
results_dict[metric].append(np.mean(tmp_dict[metric]))
# log
msg = '{:<15s} -- {}'.format(seq, {k: round(results_dict[k][-1], 3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n'
print(msg, end='')
logfile.write(msg)
msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]), 3) for k in metrics}) + '\n'
print(msg, end='')
logfile.write(msg)
logfile.close()
class Ucf101_triplet:
def __init__(self, db_dir):
self.db_dir = db_dir
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1,1]
self.im_list = os.listdir(db_dir)
self.input3_list = []
self.input5_list = []
self.gt_list = []
for item in self.im_list:
self.input3_list.append(self.transform(Image.open(join(db_dir, item , 'frame1.png'))).cuda().unsqueeze(0))
self.input5_list.append(self.transform(Image.open(join(db_dir, item , 'frame2.png'))).cuda().unsqueeze(0))
self.gt_list.append(self.transform(Image.open(join(db_dir, item , 'framet.png'))).cuda().unsqueeze(0))
def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_dir=None, output_name='output.png', resume=False):
model.eval()
results_dict = {k : [] for k in metrics}
start_idx = 0
if resume:
# fill in results_dict with prev results and find where to start from
assert os.path.exists(join(output_dir, 'results.txt')), 'no res file found to resume from!'
with open(join(output_dir, 'results.txt'), 'r') as f:
prev_lines = f.readlines()
for line in prev_lines:
if len(line) < 2:
continue
cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
for k in metrics:
results_dict[k].append(float(cur_res[k]))
start_idx += 1
logfile = open(join(output_dir, 'results.txt'), 'a')
for idx in range(len(self.im_list)):
if resume and idx < start_idx:
assert os.path.exists(join(output_dir, self.im_list[idx], output_name)), f'skipping idx {idx} but output not found!'
continue
print(f'Evaluating {self.im_list[idx]}')
t0 = time.time()
if not exists(join(output_dir, self.im_list[idx])):
os.makedirs(join(output_dir, self.im_list[idx]))
with torch.no_grad():
with model.ema_scope():
# form condition tensor and define shape of latent rep
xc = {'prev_frame': self.input3_list[idx], 'next_frame': self.input5_list[idx]}
c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc)
shape = (model.channels, c.shape[2], c.shape[3])
# run sampling and get denoised latent rep
out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape, x_T=None)
if isinstance(out, tuple): # using ddim
out = out[0]
# reconstruct interpolated frame from latent
out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list)
out = torch.clamp(out, min=-1., max=1.) # interpolated frame in [-1,1]
gt = self.gt_list[idx]
for metric in metrics:
score = getattr(utility, 'calc_{}'.format(metric.lower()))(gt, out, [self.input3_list[idx], self.input5_list[idx]])[0].item()
results_dict[metric].append(score)
imwrite(out, join(output_dir, self.im_list[idx], output_name), value_range=(-1, 1), normalize=True)
msg = '{:<15s} -- {}'.format(self.im_list[idx], {k: round(results_dict[k][-1],3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n'
print(msg, end='')
logfile.write(msg)
msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]),3) for k in metrics}) + '\n\n'
print(msg, end='')
logfile.write(msg)
logfile.close()
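
# Illustrative sketch (not part of the original module): the resume branches in
# the eval() methods above rebuild earlier scores by parsing lines that a
# previous run wrote to results.txt. The helper name `_parse_result_line` is
# hypothetical; it restates the split-and-literal_eval expression used above,
# assuming lines of the form '<name> -- {metrics dict} time taken: <seconds>'.
# `ast` is re-imported only so the sketch stands alone.
import ast


def _parse_result_line(line):
    """Return the metrics dict embedded in one results.txt line, or None if blank."""
    if len(line) < 2:  # blank separator lines are skipped, as in eval()
        return None
    # The text after '-- ' and before 'time' is the printed dict of scores.
    payload = line.strip().split('-- ')[1].split('time')[0]
    return ast.literal_eval(payload)

# A line with no 'time taken' suffix (such as the final 'Average' line) parses
# the same way, because split('time')[0] then returns the whole remainder.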
================================================
FILE: ldm/data/testsets_vqm.py
================================================
import glob
from typing import List
from PIL import Image
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.utils import save_image as imwrite
import os
from os.path import join, exists
import utility
import numpy as np
import ast
import time
class TripletTestSet:
def __init__(self):
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1,1]
def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name='output.png', resume=False):
results_dict = {k : [] for k in metrics}
start_idx = 0
if resume:
# fill in results_dict with prev results and find where to start from
assert os.path.exists(join(output_dir, 'results_vqm.txt')), 'No res file found to resume from!'
with open(join(output_dir, 'results_vqm.txt'), 'r') as f:
prev_lines = f.readlines()
for line in prev_lines:
if len(line) < 2:
continue
cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
for k in metrics:
results_dict[k].append(float(cur_res[k]))
start_idx += 1
logfile = open(join(output_dir, 'results_vqm.txt'), 'a')
for idx in range(len(self.im_list)):
if resume and idx < start_idx:
assert os.path.exists(join(output_dir, self.im_list[idx], output_name)), f'skipping idx {idx} but output not found!'
continue
print(f'Evaluating {self.im_list[idx]}')
t0 = time.time()
assert exists(join(output_dir, self.im_list[idx], output_name)), f'No interpolated frames found for {self.im_list[idx]}'
out = self.transform(Image.open(join(output_dir, self.im_list[idx], output_name))).cuda().unsqueeze(0)
gt = self.gt_list[idx]
for metric in metrics:
score = getattr(utility, 'calc_{}'.format(metric.lower()))([gt], [out], [self.input0_list[idx], self.input1_list[idx]])
results_dict[metric].append(score)
msg = '{:<15s} -- {}'.format(self.im_list[idx], {k: round(results_dict[k][-1],3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n'
print(msg, end='')
logfile.write(msg)
msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]),3) for k in metrics}) + '\n\n'
print(msg, end='')
logfile.write(msg)
logfile.close()
class Middlebury_others(TripletTestSet):
def __init__(self, db_dir):
super(Middlebury_others, self).__init__()
self.im_list = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking']
self.input0_list = []
self.input1_list = []
self.gt_list = []
for item in self.im_list:
self.input0_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame10.png'))).cuda().unsqueeze(0)) # [1,3,H,W] in [-1,1]
self.input1_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame11.png'))).cuda().unsqueeze(0)) # [1,3,H,W] in [-1,1]
self.gt_list.append(self.transform(Image.open(join(db_dir, 'gt', item , 'frame10i11.png'))).cuda().unsqueeze(0))
class Davis(TripletTestSet):
def __init__(self, db_dir):
super(Davis, self).__init__()
self.im_list = ['bike-trial', 'boxing', 'burnout', 'choreography', 'demolition', 'dive-in', 'dolphins', 'e-bike', 'grass-chopper', 'hurdles', 'inflatable', 'juggle', 'kart-turn', 'kids-turning', 'lions', 'mbike-santa', 'monkeys', 'ocean-birds', 'pole-vault', 'running', 'selfie', 'skydive', 'speed-skating', 'swing-boy', 'tackle', 'turtle', 'varanus-tree', 'vietnam', 'wings-turn']
self.input0_list = []
self.input1_list = []
self.gt_list = []
for item in self.im_list:
self.input0_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame10.png'))).cuda().unsqueeze(0))
self.input1_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame11.png'))).cuda().unsqueeze(0))
self.gt_list.append(self.transform(Image.open(join(db_dir, 'gt', item , 'frame10i11.png'))).cuda().unsqueeze(0))
class Ucf(TripletTestSet):
def __init__(self, db_dir):
super(Ucf, self).__init__()
self.im_list = os.listdir(db_dir)
self.input0_list = []
self.input1_list = []
self.gt_list = []
for item in self.im_list:
self.input0_list.append(self.transform(Image.open(join(db_dir, item , 'frame_00.png'))).cuda().unsqueeze(0))
self.input1_list.append(self.transform(Image.open(join(db_dir, item , 'frame_02.png'))).cuda().unsqueeze(0))
self.gt_list.append(self.transform(Image.open(join(db_dir, item , 'frame_01_gt.png'))).cuda().unsqueeze(0))
class Snufilm(TripletTestSet):
def __init__(self, db_dir, mode):
super(Snufilm, self).__init__()
self.mode = mode
self.input0_list = []
self.input1_list = []
self.gt_list = []
with open(join(db_dir, 'test-{}.txt'.format(mode)), 'r') as f:
triplet_list = f.read().splitlines()
self.im_list = []
for i, triplet in enumerate(triplet_list, 1):
self.im_list.append('{}-{}'.format(mode, str(i).zfill(3)))
lst = triplet.split(' ')
self.input0_list.append(self.transform(Image.open(join(db_dir, lst[0]))).cuda().unsqueeze(0))
self.input1_list.append(self.transform(Image.open(join(db_dir, lst[2]))).cuda().unsqueeze(0))
self.gt_list.append(self.transform(Image.open(join(db_dir, lst[1]))).cuda().unsqueeze(0))
class Snufilm_easy(Snufilm):
def __init__(self, db_dir):
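# Drop the last 5 characters of db_dir (presumably an '_easy' suffix) so every difficulty split reads from one shared root; the sibling classes below trim their own suffix lengths in the same way.
db_dir = db_dir[:-5]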
super(Snufilm_easy, self).__init__(db_dir, 'easy')
class Snufilm_medium(Snufilm):
def __init__(self, db_dir):
db_dir = db_dir[:-7]
super(Snufilm_medium, self).__init__(db_dir, 'medium')
class Snufilm_hard(Snufilm):
def __init__(self, db_dir):
db_dir = db_dir[:-5]
super(Snufilm_hard, self).__init__(db_dir, 'hard')
class Snufilm_extreme(Snufilm):
def __init__(self, db_dir):
db_dir = db_dir[:-8]
super(Snufilm_extreme, self).__init__(db_dir, 'extreme')
class VFITex_triplet:
def __init__(self, db_dir):
self.seq_list = os.listdir(db_dir)
self.db_dir = db_dir
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1,1]
def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name=None, resume=False):
results_dict = {k : [] for k in metrics}
start_idx = 0
if resume:
# fill in results_dict with prev results and find where to start from
assert os.path.exists(join(output_dir, 'results_vqm.txt')), 'no res file found to resume from!'
with open(join(output_dir, 'results_vqm.txt'), 'r') as f:
prev_lines = f.readlines()
for line in prev_lines:
if len(line) < 2:
continue
cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
for k in metrics:
results_dict[k].append(float(cur_res[k]))
start_idx += 1
logfile = open(join(output_dir, 'results_vqm.txt'), 'a')
for idx, seq in enumerate(self.seq_list):
if resume and idx < start_idx:
assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'skipping idx {idx} but output not found!'
continue
print(f'Evaluating {seq}')
t0 = time.time()
seqpath = join(self.db_dir, seq)
assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'No interpolated frames found for {seq}'
# interpolate between every 2 frames
gt_list, out_list, inputs_list = [], [], []
num_frames = len([f for f in os.listdir(seqpath) if f.endswith('.png')])
for t in range(1, num_frames-5, 2):
im0 = Image.open(join(seqpath, str(t+2).zfill(3)+'.png'))
im1 = Image.open(join(seqpath, str(t+3).zfill(3)+'.png'))
im2 = Image.open(join(seqpath, str(t+4).zfill(3)+'.png'))
# center crop if 4K
if '4K' in seq:
w, h = im0.size
im0 = TF.center_crop(im0, (h//2, w//2))
im1 = TF.center_crop(im1, (h//2, w//2))
im2 = TF.center_crop(im2, (h//2, w//2))
im0 = self.transform(im0).cuda().unsqueeze(0)
im1 = self.transform(im1).cuda().unsqueeze(0)
im2 = self.transform(im2).cuda().unsqueeze(0)
out = self.transform(Image.open(join(output_dir, seq, 'frame{}.png'.format(t+3)))).cuda().unsqueeze(0)
gt_list.append(im1)
out_list.append(out)
if t == 1:
inputs_list.append(im0)
inputs_list.append(im2)
# compute sequence-level scores
for metric in metrics:
score = getattr(utility, 'calc_{}'.format(metric.lower()))(gt_list, out_list, inputs_list)
results_dict[metric].append(score)
# log
msg = '{:<15s} -- {}'.format(seq, {k: round(results_dict[k][-1], 3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n'
print(msg, end='')
logfile.write(msg)
msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]), 3) for k in metrics}) + '\n'
print(msg, end='')
logfile.write(msg)
logfile.close()
class Davis90_triplet:
def __init__(self, db_dir):
self.seq_list = sorted(os.listdir(db_dir))
self.db_dir = db_dir
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1,1]
def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name=None, resume=False):
results_dict = {k : [] for k in metrics}
start_idx = 0
if resume:
# fill in results_dict with prev results and find where to start from
assert os.path.exists(join(output_dir, 'results_vqm.txt')), 'no res file found to resume from!'
with open(join(output_dir, 'results_vqm.txt'), 'r') as f:
prev_lines = f.readlines()
for line in prev_lines:
if len(line) < 2:
continue
cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
for k in metrics:
results_dict[k].append(float(cur_res[k]))
start_idx += 1
logfile = open(join(output_dir, 'results_vqm.txt'), 'a')
for idx, seq in enumerate(self.seq_list):
if resume and idx < start_idx:
assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'skipping idx {idx} but output not found!'
continue
print(f'Evaluating {seq}')
t0 = time.time()
seqpath = join(self.db_dir, seq)
assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'No interpolated frames found for {seq}'
# interpolate between every 2 frames
gt_list, out_list, inputs_list = [], [], []
num_frames = len(os.listdir(seqpath))
for t in range(0, num_frames-6, 2):
im3 = Image.open(join(seqpath, str(t+2).zfill(5)+'.jpg'))
im4 = Image.open(join(seqpath, str(t+3).zfill(5)+'.jpg'))
im5 = Image.open(join(seqpath, str(t+4).zfill(5)+'.jpg'))
im3 = self.transform(im3).cuda().unsqueeze(0)
im4 = self.transform(im4).cuda().unsqueeze(0)
im5 = self.transform(im5).cuda().unsqueeze(0)
out = self.transform(Image.open(join(output_dir, seq, 'frame{}.png'.format(t+3)))).cuda().unsqueeze(0)
gt_list.append(im4)
out_list.append(out)
if t == 0:
inputs_list.append(im3)
inputs_list.append(im5)
# compute sequence-level scores
for metric in metrics:
score = getattr(utility, 'calc_{}'.format(metric.lower()))(gt_list, out_list, inputs_list)
results_dict[metric].append(score)
# log
msg = '{:<15s} -- {}'.format(seq, {k: round(results_dict[k][-1], 3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n'
print(msg, end='')
logfile.write(msg)
msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]), 3) for k in metrics}) + '\n'
print(msg, end='')
logfile.write(msg)
logfile.close()
class Ucf101_triplet:
def __init__(self, db_dir):
self.db_dir = db_dir
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1,1]
self.im_list = os.listdir(db_dir)
self.input3_list = []
self.input5_list = []
self.gt_list = []
for item in self.im_list:
self.input3_list.append(self.transform(Image.open(join(db_dir, item , 'frame1.png'))).cuda().unsqueeze(0))
self.input5_list.append(self.transform(Image.open(join(db_dir, item , 'frame2.png'))).cuda().unsqueeze(0))
self.gt_list.append(self.transform(Image.open(join(db_dir, item , 'framet.png'))).cuda().unsqueeze(0))
def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name='output.png', resume=False):
results_dict = {k : [] for k in metrics}
start_idx = 0
if resume:
# fill in results_dict with prev results and find where to start from
assert os.path.exists(join(output_dir, 'results_vqm.txt')), 'no res file found to resume from!'
with open(join(output_dir, 'results_vqm.txt'), 'r') as f:
prev_lines = f.readlines()
for line in prev_lines:
if len(line) < 2:
continue
cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
for k in metrics:
results_dict[k].append(float(cur_res[k]))
start_idx += 1
logfile = open(join(output_dir, 'results_vqm.txt'), 'a')
for idx in range(len(self.im_list)):
if resume and idx < start_idx:
assert os.path.exists(join(output_dir, self.im_list[idx], output_name)), f'skipping idx {idx} but output not found!'
continue
print(f'Evaluating {self.im_list[idx]}')
t0 = time.time()
assert exists(join(output_dir, self.im_list[idx], output_name)), f'No interpolated frames found for {self.im_list[idx]}'
out = self.transform(Image.open(join(output_dir, self.im_list[idx], output_name))).cuda().unsqueeze(0)
gt = self.gt_list[idx]
for metric in metrics:
score = getattr(utility, 'calc_{}'.format(metric.lower()))([gt], [out], [self.input3_list[idx], self.input5_list[idx]])
results_dict[metric].append(score)
msg = '{:<15s} -- {}'.format(self.im_list[idx], {k: round(results_dict[k][-1],3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n'
print(msg, end='')
logfile.write(msg)
msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]),3) for k in metrics}) + '\n\n'
print(msg, end='')
logfile.write(msg)
logfile.close()
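
# Illustrative sketch (not part of the original module): the sequence-level
# loops above walk each video in steps of two frames and read frames t+2, t+3
# and t+4 as (previous input, ground truth, next input). The helper name
# `_triplet_indices` and its parameters are hypothetical; `start` and `margin`
# stand in for the constants used above (1 and 5 in VFITex_triplet, 0 and 6 in
# Davis90_triplet).
def _triplet_indices(num_frames, start, margin):
    """List the (prev, gt, next) frame numbers visited for one sequence."""
    return [(t + 2, t + 3, t + 4) for t in range(start, num_frames - margin, 2)]

# For an 11-frame sequence, _triplet_indices(11, 1, 5) gives
# [(3, 4, 5), (5, 6, 7), (7, 8, 9)]. Each triplet's next input is the following
# triplet's previous input, which is why the loops above add the previous frame
# to inputs_list only on the first iteration.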
================================================
FILE: ldm/data/vfitransforms.py
================================================
import random
import torch
import torchvision
import torchvision.transforms.functional as TF
def rand_crop(*args, sz):
i, j, h, w = torchvision.transforms.RandomCrop.get_params(args[0], output_size=sz)
out = []
for im in args:
out.append(TF.crop(im, i, j, h, w))
return out
def rand_flip(*args, p):
out = list(args)
if random.random() < p:
for i, im in enumerate(out):
out[i] = TF.hflip(im)
if random.random() < p:
for i, im in enumerate(out):
out[i] = TF.vflip(im)
return out
def rand_reverse(*args, p):
if random.random() < p:
return args[::-1]
else:
return args
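
# Illustrative sketch (not part of the original module): the helpers above make
# a single random decision and apply it to every frame passed in, which keeps
# the frames of a triplet aligned with one another. The function below restates
# that pattern on plain nested lists so it runs without torch; the name
# `_shared_hflip_sketch` and the list-of-rows frame format are assumptions made
# for illustration only. `random` is re-imported so the sketch stands alone.
import random


def _shared_hflip_sketch(*frames, p):
    """Mirror every frame left-to-right with probability p, using one draw."""
    out = list(frames)
    if random.random() < p:  # one draw shared by all frames
        out = [[row[::-1] for row in frame] for frame in out]
    return out

# With p=1.0 both frames flip together:
# _shared_hflip_sketch([[1, 2]], [[3, 4]], p=1.0) returns [[[2, 1]], [[4, 3]]].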
================================================
FILE: ldm/lr_scheduler.py
================================================
import numpy as np
class LambdaWarmUpCosineScheduler:
"""
note: use with a base_lr of 1.0
"""
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
self.lr_warm_up_steps = warm_up_steps
self.lr_start = lr_start
self.lr_min = lr_min
self.lr_max = lr_max
self.lr_max_decay_steps = max_decay_steps
self.last_lr = 0.
self.verbosity_interval = verbosity_interval
def schedule(self, n, **kwargs):
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
if n < self.lr_warm_up_steps:
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
self.last_lr = lr
return lr
else:
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
t = min(t, 1.0)
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
1 + np.cos(t * np.pi))
self.last_lr = lr
return lr
def __call__(self, n, **kwargs):
return self.schedule(n,**kwargs)
class LambdaWarmUpCosineScheduler2:
"""
supports repeated iterations, configurable via lists
note: use with a base_lr of 1.0.
"""
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
self.lr_warm_up_steps = warm_up_steps
self.f_start = f_start
self.f_min = f_min
self.f_max = f_max
self.cycle_lengths = cycle_lengths
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
self.last_f = 0.
self.verbosity_interval = verbosity_interval
def find_in_interval(self, n):
interval = 0
for cl in self.cum_cycles[1:]:
if n <= cl:
return interval
interval += 1
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
t = min(t, 1.0)
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
1 + np.cos(t * np.pi))
self.last_f = f
return f
def __call__(self, n, **kwargs):
return self.schedule(n, **kwargs)
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
def schedule(self, n, **kwargs):
cycle = self.find_in_interval(n)
n = n - self.cum_cycles[cycle]
if self.verbosity_interval > 0:
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
f"current cycle {cycle}")
if n < self.lr_warm_up_steps[cycle]:
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
self.last_f = f
return f
else:
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
self.last_f = f
return f
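
# Illustrative sketch (not part of the original module): a stateless restatement
# of the multiplier computed by LambdaWarmUpCosineScheduler.schedule above,
# convenient for checking a schedule by hand. The function name
# `_warmup_cosine_multiplier` is hypothetical; the arithmetic follows the two
# branches of schedule() (linear warm-up, then cosine decay clamped at lr_min).
import math


def _warmup_cosine_multiplier(n, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps):
    """Learning-rate multiplier at step n for a warm-up plus cosine schedule."""
    if n < warm_up_steps:
        # Linear ramp from lr_start (n = 0) towards lr_max (n = warm_up_steps).
        return (lr_max - lr_start) / warm_up_steps * n + lr_start
    # Fraction of the decay phase completed, clamped to 1 after max_decay_steps.
    t = (n - warm_up_steps) / (max_decay_steps - warm_up_steps)
    t = min(t, 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(t * math.pi))

# The multiplier equals lr_start at step 0, reaches lr_max when warm-up ends,
# and stays at lr_min from max_decay_steps onwards.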
================================================
FILE: ldm/models/autoencoder.py
================================================
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
from packaging import version
from ldm.modules.ema import LitEma
from contextlib import contextmanager
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.modules.diffusionmodules.model import *
from ldm.util import instantiate_from_config
class VQFlowNet(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False
):
super().__init__()
self.embed_dim = embed_dim # 3
self.n_embed = n_embed # 8192
self.image_key = image_key # 'image'
self.encoder = FlowEncoder(**ddconfig)
self.decoder = FlowDecoderWithResidual(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
self.h0 = None
self.w0 = None
self.h_padded = None
self.w_padded = None
self.pad_h = 0
self.pad_w = 0
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x, ret_feature=False):
'''
Set ret_feature = True when encoding conditions in ddpm
'''
# Pad the input first so that its height and width are divisible by min_side (computed below).
# This tolerates different f values, inputs of various sizes,
# and some operations in the DDPM UNet model.
self.h0, self.w0 = x.shape[2:]
# 8: window size for MaxViT
# 2**(nr-1): downsampling factor f of the encoder (nr = num_resolutions)
# 4: downsampling factor in the DDPM UNet
min_side = 8 * 2**(self.encoder.num_resolutions-1) * 4
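# Worked example with a hypothetical value (not taken from any config): if
# min_side were 512, an input of height 720 would get
# pad_h = 512 - (720 % 512) = 512 - 208 = 304, for a padded height of 1024.
# A 256-pixel input would give pad_h = 512 - 256 = 256 == h0, which the check
# below resets to 0 so that 256x256 patches are left unpadded.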
if self.h0 % min_side != 0:
pad_h = min_side - (self.h0 % min_side)
if pad_h == self.h0: # this is to avoid padding 256 patches
pad_h = 0
x = F.pad(x, (0, 0, 0, pad_h), mode='reflect')
self.h_padded = True
self.pad_h = pad_h
if self.w0 % min_side != 0:
pad_w = min_side - (self.w0 % min_side)
if pad_w == self.w0:
pad_w = 0
x = F.pad(x, (0, pad_w, 0, 0), mode='reflect')
self.w_padded = True
self.pad_w = pad_w
phi_list = None
if ret_feature:
h, phi_list = self.encoder(x, ret_feature)
else:
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
if ret_feature:
return quant, emb_loss, info, phi_list
return quant, emb_loss, info
def encode_to_prequant(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, quant, x_prev, x_next):
cond_dict = dict(
phi_prev_list = self.encode(x_prev, ret_feature=True)[-1],
phi_next_list = self.encode(x_next, ret_feature=True)[-1],
frame_prev = F.pad(x_prev, (0, self.pad_w, 0, self.pad_h), mode='reflect'),
frame_next = F.pad(x_next, (0, self.pad_w, 0, self.pad_h), mode='reflect')
)
quant = self.post_quant_conv(quant)
dec = self.decoder(quant, cond_dict)
# check if image is padded and return the original part only
if self.h_padded:
dec = dec[:, :, 0:self.h0, :]
if self.w_padded:
dec = dec[:, :, :, 0:self.w0]
return dec
def decode_code(self, code_b, x_prev, x_next):
quant_b = self.quantize.embed_code(code_b)
# decode() requires the neighbouring frames as conditioning
dec = self.decode(quant_b, x_prev, x_next)
return dec
def forward(self, input, x_prev, x_next, return_pred_indices=False):
quant, diff, (_,_,ind) = self.encode(input)
dec = self.decode(quant, x_prev, x_next)
if return_pred_indices:
return dec, diff, ind
return dec, diff
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
if self.global_step <= 4:
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = x.detach()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
# https://github.com/pytorch/pytorch/issues/37142
# try not to fool the heuristics
x = self.get_input(batch, self.image_key)
x_prev = self.get_input(batch, 'prev_frame')
x_next = self.get_input(batch, 'next_frame')
xrec, qloss = self(x, x_prev, x_next)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
x = self.get_input(batch, self.image_key)
x_prev = self.get_input(batch, 'prev_frame')
x_next = self.get_input(batch, 'next_frame')
xrec, qloss = self(x, x_prev, x_next)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
)
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
)
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(f"val{suffix}/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log(f"val{suffix}/aeloss", aeloss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
if version.parse(pl.__version__) >= version.parse('1.4.0'):
del log_dict_ae[f"val{suffix}/rec_loss"]
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
# return the logged values rather than the bound self.log_dict method
return {**log_dict_ae, **log_dict_disc}
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor*self.learning_rate
print("lr_d", lr_d)
print("lr_g", lr_g)
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr_g, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr_d, betas=(0.5, 0.9))
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
{
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
]
return [opt_ae, opt_disc], scheduler
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x_prev = self.get_input(batch, 'prev_frame')
x_next = self.get_input(batch, 'next_frame')
x = x.to(self.device)
if only_inputs:
log["inputs"] = x
return log
xrec, _ = self(x, x_prev, x_next)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["inputs"] = x
log["reconstructions"] = xrec
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x, x_prev, x_next)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class VQFlowNetInterface(VQFlowNet):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def encode(self, x, ret_feature=False):
'''
Set ret_feature = True when encoding conditions in ddpm
'''
# Pad the input first so its size is divisible by min_side (a multiple of 8).
# this is to tolerate different f values, various size inputs,
# and some operations in the DDPM unet model.
self.h0, self.w0 = x.shape[2:]
# 8: window size for max vit
# 2**(nr-1): f
# 4: factor of downsampling in DDPM unet
min_side = 512#8 * 2**(self.encoder.num_resolutions-1) * 16
min_side = min_side // 2 if self.h0 <= 256 else min_side
if self.h0 % min_side != 0:
pad_h = min_side - (self.h0 % min_side)
if pad_h == self.h0: # this is to avoid padding 256 patches
pad_h = 0
x = F.pad(x, (0, 0, 0, pad_h), mode='reflect')
self.h_padded = True
self.pad_h = pad_h
if self.w0 % min_side != 0:
pad_w = min_side - (self.w0 % min_side)
if pad_w == self.w0:
pad_w = 0
x = F.pad(x, (0, pad_w, 0, 0), mode='reflect')
self.w_padded = True
self.pad_w = pad_w
phi_list = None
if ret_feature:
h, phi_list = self.encoder(x, ret_feature)
else:
h = self.encoder(x)
h = self.quant_conv(h)
if ret_feature:
return h, phi_list
return h
def decode(self, h, x_prev, x_next, phi_prev_list, phi_next_list, force_not_quantize=False):
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
else:
quant = h
cond_dict = dict(
phi_prev_list = phi_prev_list,
phi_next_list = phi_next_list,
frame_prev = F.pad(x_prev, (0, self.pad_w, 0, self.pad_h), mode='reflect'),
frame_next = F.pad(x_next, (0, self.pad_w, 0, self.pad_h), mode='reflect')
)
quant = self.post_quant_conv(quant)
dec = self.decoder(quant, cond_dict)
# check if image is padded and return the original part only
if self.h_padded:
dec = dec[:, :, 0:self.h0, :]
if self.w_padded:
dec = dec[:, :, :, 0:self.w0]
return dec
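# Illustrative sketch, not part of the repository: the padding arithmetic used by
# VQFlowNetInterface.encode above, reduced to plain integers. The helper names
# pad_amount and padded_shape are hypothetical; base=512 is taken from the
# hard-coded min_side in encode. It only shows which padded sizes result, under
# the assumption that encode behaves exactly as written above.

```python
def pad_amount(size, min_side):
    """Extra pixels needed so that `size` becomes a multiple of `min_side`.

    Mirrors the branch in encode: when the computed padding equals the input
    size itself (e.g. 128 with min_side 256), no padding is applied.
    """
    if size % min_side == 0:
        return 0
    pad = min_side - (size % min_side)
    return 0 if pad == size else pad


def padded_shape(h0, w0, base=512):
    """Return the (height, width) that encode would feed to the encoder."""
    # min_side is halved for small inputs; the decision uses the height only
    min_side = base // 2 if h0 <= 256 else base
    return h0 + pad_amount(h0, min_side), w0 + pad_amount(w0, min_side)


if __name__ == "__main__":
    for h, w in [(1080, 1920), (256, 448), (128, 128)]:
        print((h, w), "->", padded_shape(h, w))
```

# For example, a 1080x1920 frame is padded to 1536x2048, while a 128x128 patch
# is left unpadded because the computed padding would equal the patch size.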
================================================
FILE: ldm/models/diffusion/__init__.py
================================================
================================================
FILE: ldm/models/diffusion/ddim.py
================================================
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if isinstance(attr, torch.Tensor):
# keep sampler buffers on the same device as the wrapped model
attr = attr.to(self.model.device)
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
if verbose:
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
verbose=verbose
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,verbose=True):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
if verbose:
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) if verbose else time_range
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
# ddim_sigmas_for_original_num_steps is registered on the sampler in make_schedule, not on the model
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
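# Illustrative sketch, not part of the repository: the update performed at the
# end of p_sample_ddim above, written for scalars with the standard library
# only. The function name ddim_step and the numeric values are hypothetical.
# a_t and a_prev stand for the cumulative alpha products at the current and
# previous sampling steps, eps for the predicted noise.

```python
import math


def ddim_step(x_t, eps, a_t, a_prev, sigma_t=0.0, noise=0.0):
    """One DDIM update on scalars; returns (x_prev, pred_x0)."""
    # current prediction for x_0
    pred_x0 = (x_t - math.sqrt(1.0 - a_t) * eps) / math.sqrt(a_t)
    # direction pointing to x_t
    dir_xt = math.sqrt(1.0 - a_prev - sigma_t ** 2) * eps
    x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt + sigma_t * noise
    return x_prev, pred_x0


if __name__ == "__main__":
    x0, eps, a_t, a_prev = 0.5, -0.3, 0.6, 0.8
    # forward-noise x0 to level a_t, then take one deterministic step (sigma_t = 0)
    x_t = math.sqrt(a_t) * x0 + math.sqrt(1.0 - a_t) * eps
    print(ddim_step(x_t, eps, a_t, a_prev))
```

# With sigma_t = 0 (eta = 0) and an exact noise prediction, the step recovers
# x0 as pred_x0 and lands on the noised sample at level a_prev.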
================================================
FILE: ldm/models/diffusion/ddpm.py
================================================
"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
https://github.com/CompVis/taming-transformers
-- merci
"""
import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager
from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only
from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.models.autoencoder import *
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler
__conditioning_keys__ = {'concat': 'c_concat',
'crossattn': 'c_crossattn',
'adm': 'y'}
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class DDPM(pl.LightningModule):
# classic DDPM with Gaussian diffusion, in image space
def __init__(self,
unet_config,
timesteps=1000,
beta_schedule="linear",
loss_type="l2",
ckpt_path=None,
ignore_keys=[],
load_only_unet=False,
monitor="val/loss",
use_ema=True,
first_stage_key="image",
image_size=256,
channels=3,
log_every_t=100,
clip_denoised=True,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
given_betas=None,
original_elbo_weight=0.,
v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
l_simple_weight=1.,
conditioning_key=None,
parameterization="eps", # all assuming fixed variance schedules
scheduler_config=None,
use_positional_encodings=False,
learn_logvar=False,
logvar_init=0.,
):
super().__init__()
assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
self.parameterization = parameterization
print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
self.cond_stage_model = None
self.clip_denoised = clip_denoised
self.log_every_t = log_every_t
self.first_stage_key = first_stage_key
self.image_size = image_size # try conv?
self.channels = channels
self.use_positional_encodings = use_positional_encodings
self.model = DiffusionWrapper(unet_config, conditioning_key)
count_params(self.model, verbose=True)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.use_scheduler = scheduler_config is not None
if self.use_scheduler:
self.scheduler_config = scheduler_config
self.v_posterior = v_posterior
self.original_elbo_weight = original_elbo_weight
self.l_simple_weight = l_simple_weight
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
self.loss_type = loss_type
self.learn_logvar = learn_logvar
self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
if self.learn_logvar:
self.logvar = nn.Parameter(self.logvar, requires_grad=True)
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if exists(given_betas):
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
1. - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', to_torch(posterior_variance))
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
self.register_buffer('posterior_mean_coef1', to_torch(
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
self.register_buffer('posterior_mean_coef2', to_torch(
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
if self.parameterization == "eps":
lvlb_weights = self.betas ** 2 / (
2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
elif self.parameterization == "x0":
lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
else:
raise NotImplementedError("mu not supported")
# TODO how to choose this term
lvlb_weights[0] = lvlb_weights[1]
self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
assert not torch.isnan(self.lvlb_weights).any()
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
sd = torch.load(path, map_location="cpu")
if "state_dict" in list(sd.keys()):
sd = sd["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
return (
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, clip_denoised: bool):
model_out = self.model(x, t)
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
if clip_denoised:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
intermediates = [img]
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
clip_denoised=self.clip_denoised)
if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
intermediates.append(img)
if return_intermediates:
return img, intermediates
return img
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size),
return_intermediates=return_intermediates)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
def get_loss(self, pred, target, mean=True):
if self.loss_type == 'l1':
loss = (target - pred).abs()
if mean:
loss = loss.mean()
elif self.loss_type == 'l2':
if mean:
loss = torch.nn.functional.mse_loss(target, pred)
else:
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
else:
raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
return loss
def p_losses(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_out = self.model(x_noisy, t)
loss_dict = {}
if self.parameterization == "eps":
target = noise
elif self.parameterization == "x0":
target = x_start
else:
raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
log_prefix = 'train' if self.training else 'val'
loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
loss_simple = loss.mean() * self.l_simple_weight
loss_vlb = (self.lvlb_weights[t] * loss).mean()
loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
loss = loss_simple + self.original_elbo_weight * loss_vlb
loss_dict.update({f'{log_prefix}/loss': loss})
return loss, loss_dict
def forward(self, x, *args, **kwargs):
# b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
# assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
return self.p_losses(x, t, *args, **kwargs)
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = rearrange(x, 'b h w c -> b c h w')
x = x.to(memory_format=torch.contiguous_format).float()
return x
def shared_step(self, batch):
x = self.get_input(batch, self.first_stage_key)
loss, loss_dict = self(x)
return loss, loss_dict
def training_step(self, batch, batch_idx):
loss, loss_dict = self.shared_step(batch)
self.log_dict(loss_dict, prog_bar=True,
logger=True, on_step=True, on_epoch=True)
self.log("global_step", float(self.global_step),
prog_bar=True, logger=True, on_step=True, on_epoch=False)
if self.use_scheduler:
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
return loss
@torch.no_grad()
def validation_step(self, batch, batch_idx):
_, loss_dict_no_ema = self.shared_step(batch)
with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch)
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self.model)
def _get_rows_from_list(self, samples):
n_imgs_per_row = len(samples)
denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
@torch.no_grad()
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
log = dict()
x = self.get_input(batch, self.first_stage_key)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
x = x.to(self.device)[:N]
log["inputs"] = x
# get diffusion row
diffusion_row = list()
x_start = x[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(x_start)
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
diffusion_row.append(x_noisy)
log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
if sample:
# get denoise row
with self.ema_scope("Plotting"):
samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
log["samples"] = samples
log["denoise_row"] = self._get_rows_from_list(denoise_row)
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
else:
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
if self.learn_logvar:
params = params + [self.logvar]
opt = torch.optim.AdamW(params, lr=lr)
return opt
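# Illustrative sketch, not part of the repository: the schedule quantities that
# DDPM.register_schedule derives from the betas, computed with plain Python
# lists for the case v_posterior = 0. The function name schedule_terms and the
# three-step beta list are hypothetical; real schedules come from
# make_beta_schedule. It also checks the identity stated in the comment inside
# register_schedule: posterior_variance == 1 / (1 / (1 - alpha_cumprod_tm1) + alpha_t / beta_t).

```python
def schedule_terms(betas):
    """Return (alphas, alphas_cumprod, alphas_cumprod_prev, posterior_variance)."""
    alphas = [1.0 - b for b in betas]
    alphas_cumprod = []
    running = 1.0
    for a in alphas:
        running *= a
        alphas_cumprod.append(running)
    # shifted by one step, with 1.0 for the step before the first
    alphas_cumprod_prev = [1.0] + alphas_cumprod[:-1]
    posterior_variance = [
        b * (1.0 - ap) / (1.0 - ac)
        for b, ap, ac in zip(betas, alphas_cumprod_prev, alphas_cumprod)
    ]
    return alphas, alphas_cumprod, alphas_cumprod_prev, posterior_variance


if __name__ == "__main__":
    betas = [0.1, 0.2, 0.3]  # toy values, for illustration only
    alphas, ac, ac_prev, post_var = schedule_terms(betas)
    for t in range(1, len(betas)):
        alt = 1.0 / (1.0 / (1.0 - ac_prev[t]) + alphas[t] / betas[t])
        print(t, post_var[t], alt)
```

# The posterior variance is exactly 0 at t = 0, which is why the code above
# clips it before taking the logarithm.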
class LatentDiffusion(DDPM):
"""main class"""
def __init__(self,
first_stage_config,
cond_stage_config,
num_timesteps_cond=None,
cond_stage_key="image",
cond_stage_trainable=False,
concat_mode=True,
cond_stage_forward=None,
conditioning_key=None,
scale_factor=1.0,
*args, **kwargs):
self.num_timesteps_cond = default(num_timesteps_cond, 1)
assert self.num_timesteps_cond <= kwargs['timesteps']
# for backwards compatibility after implementation of DiffusionWrapper
if conditioning_key is None:
conditioning_key = 'concat' if concat_mode else 'crossattn'
if cond_stage_config == '__is_unconditional__':
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except Exception:
self.num_downs = 0
self.scale_factor = scale_factor
self.instantiate_first_stage(first_stage_config)
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.restarted_from_ckpt = False
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys)
self.restarted_from_ckpt = True
def make_cond_schedule(self, ):
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
self.cond_ids[:self.num_timesteps_cond] = ids
def register_schedule(self,
given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
self.shorten_cond_schedule = self.num_timesteps_cond > 1
if self.shorten_cond_schedule:
self.make_cond_schedule()
def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
param.requires_grad = False
def instantiate_cond_stage(self, config):
if not self.cond_stage_trainable:
if config == "__is_first_stage__":
print("Using first stage also as cond stage.")
self.cond_stage_model = self.first_stage_model
elif config == "__is_unconditional__":
print(f"Training {self.__class__.__name__} as an unconditional model.")
self.cond_stage_model = None
# self.be_unconditional = True
else:
model = instantiate_from_config(config)
self.cond_stage_model = model.eval()
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
else:
assert config != '__is_first_stage__'
assert config != '__is_unconditional__'
model = instantiate_from_config(config)
self.cond_stage_model = model
def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
denoise_row = []
for zd in tqdm(samples, desc=desc):
denoise_row.append(self.decode_first_stage(zd.to(self.device),
force_not_quantize=force_no_decoder_quantization))
n_imgs_per_row = len(denoise_row)
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
def get_first_stage_encoding(self, encoder_posterior):
if isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
else:
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
return self.scale_factor * z
def get_learned_conditioning(self, c):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
c = self.cond_stage_model.encode(c)
else:
c = self.cond_stage_model(c)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
return c
@torch.no_grad()
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
cond_key=None, return_original_cond=False, bs=None):
x = super().get_input(batch, k)
if bs is not None:
x = x[:bs]
x = x.to(self.device)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
if self.model.conditioning_key is not None:
if cond_key is None:
cond_key = self.cond_stage_key
if cond_key != self.first_stage_key:
if cond_key in ['caption', 'coordinates_bbox']:
xc = batch[cond_key]
elif cond_key == 'class_label':
xc = batch
else:
xc = super().get_input(batch, cond_key).to(self.device)
else:
xc = x
if not self.cond_stage_trainable or force_c_encode:
if isinstance(xc, dict) or isinstance(xc, list):
# import pudb; pudb.set_trace()
c = self.get_learned_conditioning(xc)
else:
c = self.get_learned_conditioning(xc.to(self.device))
else:
c = xc
if bs is not None:
c = c[:bs]
if self.use_positional_encodings:
pos_x, pos_y = self.compute_latent_shifts(batch)
ckey = __conditioning_keys__[self.model.conditioning_key]
c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
else:
c = None
xc = None
if self.use_positional_encodings:
pos_x, pos_y = self.compute_latent_shifts(batch)
c = {'pos_x': pos_x, 'pos_y': pos_y}
out = [z, c]
if return_first_stage_outputs:
xrec = self.decode_first_stage(z)
out.extend([x, xrec])
if return_original_cond:
out.append(xc)
return out
@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
if predict_cids:
if z.dim() == 4:
z = torch.argmax(z.exp(), dim=1).long()
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
z = rearrange(z, 'b h w c -> b c h w').contiguous()
z = 1. / self.scale_factor * z
return self.first_stage_model.decode(z)
@torch.no_grad()
def encode_first_stage(self, x):
return self.first_stage_model.encode(x)
def shared_step(self, batch, **kwargs):
x, c = self.get_input(batch, self.first_stage_key)
loss = self(x, c)
return loss
def forward(self, x, c, *args, **kwargs):
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
if self.model.conditioning_key is not None:
assert c is not None
if self.cond_stage_trainable:
c = self.get_learned_conditioning(c)
if self.shorten_cond_schedule: # TODO: drop this option
tc = self.cond_ids[t].to(self.device)
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
# hybrid case, cond is expected to be a dict
pass
else:
if not isinstance(cond, list):
cond = [cond]
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
cond = {key: cond}
x_recon = self.model(x_noisy, t, **cond)
if isinstance(x_recon, tuple) and not return_ids:
return x_recon[0]
else:
return x_recon
def p_losses(self, x_start, cond, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_output = self.apply_model(x_noisy, t, cond)
loss_dict = {}
prefix = 'train' if self.training else 'val'
if self.parameterization == "x0":
target = x_start
elif self.parameterization == "eps":
target = noise
else:
raise NotImplementedError()
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
logvar_t = self.logvar[t].to(self.device)
loss = loss_simple / torch.exp(logvar_t) + logvar_t
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
if self.learn_logvar:
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
loss_dict.update({'logvar': self.logvar.data.mean()})
loss = self.l_simple_weight * loss.mean()
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
loss += (self.original_elbo_weight * loss_vlb)
loss_dict.update({f'{prefix}/loss': loss})
return loss, loss_dict
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
return_x0=False, score_corrector=None, corrector_kwargs=None):
t_in = t
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
if score_corrector is not None:
assert self.parameterization == "eps"
model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
if return_codebook_ids:
model_out, logits = model_out
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
else:
raise NotImplementedError()
if clip_denoised:
x_recon.clamp_(-1., 1.)
if quantize_denoised:
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
if return_codebook_ids:
return model_mean, posterior_variance, posterior_log_variance, logits
elif return_x0:
return model_mean, posterior_variance, posterior_log_variance, x_recon
else:
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
return_codebook_ids=False, quantize_denoised=False, return_x0=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
b, *_, device = *x.shape, x.device
outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
return_codebook_ids=return_codebook_ids,
quantize_denoised=quantize_denoised,
return_x0=return_x0,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if return_codebook_ids:
raise DeprecationWarning("Support dropped.")
model_mean, _, model_log_variance, logits = outputs
elif return_x0:
model_mean, _, model_log_variance, x0 = outputs
else:
model_mean, _, model_log_variance = outputs
noise = noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
# if return_codebook_ids:
# return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
if return_x0:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
else:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
log_every_t=None):
if not log_every_t:
log_every_t = self.log_every_t
timesteps = self.num_timesteps
if batch_size is not None:
b = batch_size
shape = [batch_size] + list(shape)
else:
b = batch_size = shape[0]
if x_T is None:
img = torch.randn(shape, device=self.device)
else:
img = x_T
intermediates = []
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
if start_T is not None:
timesteps = min(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
total=timesteps) if verbose else reversed(
range(0, timesteps))
if isinstance(temperature, float):
temperature = [temperature] * timesteps
for i in iterator:
ts = torch.full((b,), i, device=self.device, dtype=torch.long)
if self.shorten_cond_schedule:
assert self.model.conditioning_key != 'hybrid'
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
img, x0_partial = self.p_sample(img, cond, ts,
clip_denoised=self.clip_denoised,
quantize_denoised=quantize_denoised, return_x0=True,
temperature=temperature[i], noise_dropout=noise_dropout,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
if callback: callback(i)
if img_callback: img_callback(img, i)
return img, intermediates
@torch.no_grad()
def p_sample_loop(self, cond, shape, return_intermediates=False,
x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, start_T=None,
log_every_t=None):
if not log_every_t:
log_every_t = self.log_every_t
device = self.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
intermediates = [img]
if timesteps is None:
timesteps = self.num_timesteps
if start_T is not None:
timesteps = min(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
range(0, timesteps))
if mask is not None:
assert x0 is not None
assert x0.shape[2:] == mask.shape[2:]  # spatial size (H and W) has to match
for i in iterator:
ts = torch.full((b,), i, device=device, dtype=torch.long)
if self.shorten_cond_schedule:
assert self.model.conditioning_key != 'hybrid'
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
img = self.p_sample(img, cond, ts,
clip_denoised=self.clip_denoised,
quantize_denoised=quantize_denoised)
if mask is not None:
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(img)
if callback: callback(i)
if img_callback: img_callback(img, i)
if return_intermediates:
return img, intermediates
return img
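# Illustrative sketch (not part of the repository): which loop indices
# p_sample_loop above appends to `intermediates`. The helper name logged_steps
# and the value log_every_t=100 are assumptions for illustration; timesteps=1000
# matches the default in register_schedule.

```python
def logged_steps(timesteps, log_every_t):
    """Hypothetical helper: indices i at which an intermediate is recorded."""
    steps = []
    # The sampler walks from the noisiest step down to step 0.
    for i in reversed(range(timesteps)):
        if i % log_every_t == 0 or i == timesteps - 1:
            steps.append(i)
    return steps


steps = logged_steps(timesteps=1000, log_every_t=100)
print(steps)
```

# Note that p_sample_loop seeds `intermediates` with the initial noise image,
# so the returned list holds one more entry than this helper reports.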
@torch.no_grad()
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
verbose=True, timesteps=None, quantize_denoised=False,
mask=None, x0=None, shape=None,**kwargs):
if shape is None:
shape = (batch_size, self.channels, self.image_size, self.image_size)
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
return self.p_sample_loop(cond,
shape,
return_intermediates=return_intermediates, x_T=x_T,
verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
mask=mask, x0=x0)
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
if ddim:
ddim_sampler = DDIMSampler(self)
shape = (self.channels, self.image_size, self.image_size)
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
shape, cond, verbose=False, **kwargs)
else:
samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
return_intermediates=True,**kwargs)
return samples, intermediates
@torch.no_grad()
def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
plot_diffusion_rows=True, **kwargs):
use_ddim = ddim_steps is not None
log = dict()
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=True,
return_original_cond=True,
bs=N)
N = min(x.shape[0], N)
n_row = min(x.shape[0], n_row)
log["inputs"] = x
log["reconstruction"] = xrec
if self.model.conditioning_key is not None:
if hasattr(self.cond_stage_model, "decode"):
xc = self.cond_stage_model.decode(c)
log["conditioning"] = xc
elif self.cond_stage_key in ["caption"]:
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
log["conditioning"] = xc
elif self.cond_stage_key == 'class_label':
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
log['conditioning'] = xc
elif isimage(xc):
log["conditioning"] = xc
if ismap(xc):
log["original_conditioning"] = self.to_rgb(xc)
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
log["diffusion_row"] = diffusion_grid
if sample:
# get denoise row
with self.ema_scope("Plotting"):
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
ddim_steps=ddim_steps,eta=ddim_eta)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
x_samples = self.decode_first_stage(samples)
log["samples"] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log["denoise_row"] = denoise_grid
if quantize_denoised:
# also display when quantizing x0 while sampling
with self.ema_scope("Plotting Quantized Denoised"):
samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
ddim_steps=ddim_steps,eta=ddim_eta,
quantize_denoised=True)
# samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
# quantize_denoised=True)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_x0_quantized"] = x_samples
if inpaint:
# make a simple center square
b, h, w = z.shape[0], z.shape[2], z.shape[3]
mask = torch.ones(N, h, w).to(self.device)
# zeros will be filled in
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
mask = mask[:, None, ...]
with self.ema_scope("Plotting Inpaint"):
samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_inpainting"] = x_samples
log["mask"] = mask
# outpaint: invert the mask so the center is kept and the border is filled in
mask = 1. - mask
with self.ema_scope("Plotting Outpaint"):
samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
ddim_steps=ddim_steps, x0=z[:N], mask=mask)
x_samples = self.decode_first_stage(samples.to(self.device))
log["samples_outpainting"] = x_samples
if plot_progressive_rows:
with self.ema_scope("Plotting Progressives"):
img, progressives = self.progressive_denoising(c,
shape=(self.channels, self.image_size, self.image_size),
batch_size=N)
prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
log["progressive_row"] = prog_row
if return_keys:
if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
return log
else:
return {key: log[key] for key in return_keys}
return log
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.model.parameters())
if self.cond_stage_trainable:
print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
params = params + list(self.cond_stage_model.parameters())
if self.learn_logvar:
print('Diffusion model optimizing logvar')
params.append(self.logvar)
opt = torch.optim.AdamW(params, lr=lr)
if self.use_scheduler:
assert 'target' in self.scheduler_config
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
}]
return [opt], scheduler
return opt
@torch.no_grad()
def to_rgb(self, x):
x = x.float()
if not hasattr(self, "colorize"):
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
x = nn.functional.conv2d(x, weight=self.colorize)
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
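# Illustrative sketch (not part of the repository): the min-max rescaling that
# to_rgb above applies after the random 1x1 convolution, written for a flat
# list of floats. The helper name rescale_to_unit_range is an assumption.

```python
def rescale_to_unit_range(values):
    """Hypothetical helper: map values linearly so min -> -1 and max -> +1."""
    lo, hi = min(values), max(values)
    # Same formula as to_rgb: 2 * (x - min) / (max - min) - 1.
    # Like the original, this divides by zero when all values are equal.
    return [2.0 * (v - lo) / (hi - lo) - 1.0 for v in values]


rescaled = rescale_to_unit_range([0.5, 2.0, 3.5])
print(rescaled)
```

# The result always spans the range [-1, 1], which matches the image value
# range the logging utilities appear to expect.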
class DiffusionWrapper(pl.LightningModule):
def __init__(self, diff_model_config, conditioning_key):
super().__init__()
self.diffusion_model = instantiate_from_config(diff_model_config)
self.conditioning_key = conditioning_key
assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'mcvd']
def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
if self.conditioning_key is None:
out = self.diffusion_model(x, t)
elif self.conditioning_key == 'concat':
xc = torch.cat([x] + c_concat, dim=1)
out = self.diffusion_model(xc, t)
elif self.conditioning_key == 'crossattn':
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(x, t, context=cc)
elif self.conditioning_key == 'hybrid':
xc = torch.cat([x] + c_concat, dim=1)
cc = torch.cat(c_crossattn, 1)
out = self.diffusion_model(xc, t, context=cc)
elif self.conditioning_key == 'adm':
cc = c_crossattn[0]
out = self.diffusion_model(x, t, y=cc)
elif self.conditioning_key == 'mcvd':
out = self.diffusion_model(x, t, cond=c_concat[0])
else:
raise NotImplementedError()
return out
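# Illustrative sketch (not part of the repository): how the conditioning_key
# branches in DiffusionWrapper.forward above change what the inner model sees.
# The helper name inner_model_inputs and all sizes are assumptions; it also
# assumes cross-attention inputs are shaped (batch, tokens, dim), so that
# concatenating on dim 1 adds token counts. The 'adm' and 'mcvd' branches are
# omitted because they pass a single conditioning tensor through unchanged.

```python
def inner_model_inputs(conditioning_key, x_channels, concat_channels=(), crossattn_tokens=()):
    """Hypothetical helper: input channels and context length per branch."""
    if conditioning_key is None:
        return {"in_channels": x_channels, "context_tokens": None}
    if conditioning_key == "concat":
        # torch.cat([x] + c_concat, dim=1) stacks along the channel axis.
        return {"in_channels": x_channels + sum(concat_channels), "context_tokens": None}
    if conditioning_key == "crossattn":
        # torch.cat(c_crossattn, 1) joins the conditioning sequences.
        return {"in_channels": x_channels, "context_tokens": sum(crossattn_tokens)}
    if conditioning_key == "hybrid":
        return {"in_channels": x_channels + sum(concat_channels),
                "context_tokens": sum(crossattn_tokens)}
    raise NotImplementedError(conditioning_key)


concat_case = inner_model_inputs("concat", x_channels=3, concat_channels=[6])
hybrid_case = inner_model_inputs("hybrid", 3, concat_channels=[6], crossattn_tokens=[77])
print(concat_case, hybrid_case)
```

# In the 'concat' case the inner model's first convolution therefore has to
# accept the summed channel count, not just the channels of x.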
#########################################################
#########################################################
#########################################################
#########################################################
#########################################################
#########################################################
class LatentDiffusionVFI(DDPM):
"""main class"""
def __init__(self,
first_stage_config,
cond_stage_config,
num_timesteps_cond=None,
cond_stage_key="image",
cond_stage_trainable=False,
concat_mode=True,
cond_stage_forward=None,
conditioning_key=None,
scale_factor=1.0,
*args, **kwargs):
self.num_timesteps_cond = default(num_timesteps_cond, 1)
assert self.num_timesteps_cond <= kwargs['timesteps']
# for backwards compatibility after implementation of DiffusionWrapper
if conditioning_key is None:
conditioning_key = 'concat' if concat_mode else 'crossattn'
if cond_stage_config == '__is_unconditional__':
conditioning_key = None
ckpt_path = kwargs.pop("ckpt_path", None)
ignore_keys = kwargs.pop("ignore_keys", [])
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
self.concat_mode = concat_mode
self.cond_stage_trainable = cond_stage_trainable
self.cond_stage_key = cond_stage_key
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
except Exception:
self.num_downs = 0
self.scale_factor = scale_factor
self.instantiate_first_stage(first_stage_config)
self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False
self.restarted_from_ckpt = False
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys)
self.restarted_from_ckpt = True
def make_cond_schedule(self, ):
self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
self.cond_ids[:self.num_timesteps_cond] = ids
def register_schedule(self,
given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
self.shorten_cond_schedule = self.num_timesteps_cond > 1
if self.shorten_cond_schedule:
self.make_cond_schedule()
def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
param.requires_grad = False
def instantiate_cond_stage(self, config):
if not self.cond_stage_trainable:
if config == "__is_first_stage__":
print("Using first stage also as cond stage.")
self.cond_stage_model = self.first_stage_model
elif config == "__is_unconditional__":
print(f"Training {self.__class__.__name__} as an unconditional model.")
self.cond_stage_model = None
# self.be_unconditional = True
else:
model = instantiate_from_config(config)
self.cond_stage_model = model.eval()
self.cond_stage_model.train = disabled_train
for param in self.cond_stage_model.parameters():
param.requires_grad = False
else:
assert config != '__is_first_stage__'
assert config != '__is_unconditional__'
model = instantiate_from_config(config)
self.cond_stage_model = model
def _get_denoise_row_from_list(self, samples, xc=None, phi_prev_list=None, phi_next_list=None, desc='', force_no_decoder_quantization=False):
denoise_row = []
for zd in tqdm(samples, desc=desc):
denoise_row.append(self.decode_first_stage(zd.to(self.device), xc, phi_prev_list, phi_next_list,
force_not_quantize=force_no_decoder_quantization))
n_imgs_per_row = len(denoise_row)
denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
return denoise_grid
def get_first_stage_encoding(self, encoder_posterior):
if isinstance(encoder_posterior, torch.Tensor):
z = encoder_posterior
else:
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
return self.scale_factor * z
def get_learned_conditioning(self, c):
phi_prev_list, phi_next_list = None, None
if isinstance(c, dict) and 'prev_frame' in c.keys():
c_prev, phi_prev_list = self.cond_stage_model.encode(c['prev_frame'], ret_feature=True)
c_next, phi_next_list = self.cond_stage_model.encode(c['next_frame'], ret_feature=True)
c = torch.cat([c_prev, c_next], dim=1)
else:
c = self.cond_stage_model.encode(c)
return c, phi_prev_list, phi_next_list
@torch.no_grad()
def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
cond_key=None, return_original_cond=False, return_phi=False, bs=None):
x = super().get_input(batch, k)
if bs is not None:
x = x[:bs]
x = x.to(self.device)
encoder_posterior = self.encode_first_stage(x)
z = self.get_first_stage_encoding(encoder_posterior).detach()
phi_prev_list, phi_next_list = None, None
assert self.model.conditioning_key is not None
if cond_key is None:
cond_key = self.cond_stage_key
assert cond_key == 'past_future_frames'
xc = {'prev_frame': super().get_input(batch, 'prev_frame'),
'next_frame': super().get_input(batch, 'next_frame')}
if not self.cond_stage_trainable or force_c_encode:
if isinstance(xc, dict) or isinstance(xc, list):
# import pudb; pudb.set_trace()
c, phi_prev_list, phi_next_list = self.get_learned_conditioning(xc)
else:
c, phi_prev_list, phi_next_list = self.get_learned_conditioning(xc.to(self.device))
else:
c = xc
if bs is not None:
c = c[:bs]
if isinstance(xc, dict):
xc['prev_frame'] = xc['prev_frame'][:bs]
xc['next_frame'] = xc['next_frame'][:bs]
if phi_prev_list and phi_next_list:
phi_prev_list = [phi_prev[:bs] for phi_prev in phi_prev_list]
phi_next_list = [phi_next[:bs] for phi_next in phi_next_list]
out = [z, c]
if return_first_stage_outputs:
xrec = self.decode_first_stage(z, xc, phi_prev_list, phi_next_list)
out.extend([x, xrec])
if return_original_cond:
out.append(xc)
if return_phi:
out.append(phi_prev_list)
out.append(phi_next_list)
return out
@torch.no_grad()
def decode_first_stage(self, z, xc=None, phi_prev_list=None, phi_next_list=None, predict_cids=False, force_not_quantize=False):
if predict_cids:
if z.dim() == 4:
z = torch.argmax(z.exp(), dim=1).long()
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
z = rearrange(z, 'b h w c -> b c h w').contiguous()
z = 1. / self.scale_factor * z
return self.first_stage_model.decode(z, xc['prev_frame'], xc['next_frame'], phi_prev_list, phi_next_list, force_not_quantize=predict_cids or force_not_quantize)
@torch.no_grad()
def encode_first_stage(self, x):
return self.first_stage_model.encode(x)
def shared_step(self, batch, **kwargs):
x, c = self.get_input(batch, self.first_stage_key)
loss = self(x, c)
return loss
def forward(self, x, c, *args, **kwargs):
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
if self.model.conditioning_key is not None:
assert c is not None
if self.cond_stage_trainable:
c, _, _ = self.get_learned_conditioning(c)
if self.shorten_cond_schedule: # TODO: drop this option
tc = self.cond_ids[t].to(self.device)
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
return self.p_losses(x, c, t, *args, **kwargs)
def apply_model(self, x_noisy, t, cond, return_ids=False):
if isinstance(cond, dict):
# hybrid case, cond is expected to be a dict
pass
else:
if not isinstance(cond, list):
cond = [cond]
key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
cond = {key: cond}
x_recon = self.model(x_noisy, t, **cond)
if isinstance(x_recon, tuple) and not return_ids:
return x_recon[0]
else:
return x_recon
def p_losses(self, x_start, cond, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
model_output = self.apply_model(x_noisy, t, cond)
loss_dict = {}
prefix = 'train' if self.training else 'val'
if self.parameterization == "x0":
target = x_start
elif self.parameterization == "eps":
target = noise
else:
raise NotImplementedError()
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
logvar_t = self.logvar[t].to(self.device)
loss = loss_simple / torch.exp(logvar_t) + logvar_t
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
if self.learn_logvar:
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
loss_dict.update({'logvar': self.logvar.data.mean()})
loss = self.l_simple_weight * loss.mean()
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
loss += (self.original_elbo_weight * loss_vlb)
loss_dict.update({f'{prefix}/loss': loss})
return loss, loss_dict
def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
return_x0=False, score_corrector=None, corrector_kwargs=None):
t_in = t
model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
if score_corrector is not None:
assert self.parameterization == "eps"
model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
if return_codebook_ids:
model_out, logits = model_out
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
else:
raise NotImplementedError()
if clip_denoised:
x_recon.clamp_(-1., 1.)
if quantize_denoised:
x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
if return_codebook_ids:
return model_mean, posterior_variance, posterior_log_variance, logits
elif return_x0:
return model_mean, posterior_variance, posterior_log_variance, x_recon
else:
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
return_codebook_ids=False, quantize_denoised=False, return_x0=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
b, *_, device = *x.shape, x.device
outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
return_codebook_ids=return_codebook_ids,
quantize_denoised=quantize_denoised,
return_x0=return_x0,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if return_codebook_ids:
raise DeprecationWarning("Support dropped.")
model_mean, _, model_log_variance, logits = outputs
elif return_x0:
model_mean, _, model_log_variance, x0 = outputs
else:
model_mean, _, model_log_variance = outputs
noise = noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
# if return_codebook_ids:
# return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
if return_x0:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
else:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
log_every_t=None):
if not log_every_t:
log_every_t = self.log_every_t
timesteps = self.num_timesteps
if batch_size is not None:
b = batch_size
shape = [batch_size] + list(shape)
else:
b = batch_size = shape[0]
if x_T is None:
img = torch.randn(shape, device=self.device)
else:
img = x_T
intermediates = []
if cond is not None:
if isinstance(cond, dict):
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
else:
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
if start_T is not None:
timesteps = min(timesteps, start_T)
iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
total=timesteps) if verbose else reversed(
range(0, timesteps))
if isinstance(temperature, float):
temperature = [temperature] * timesteps
for i in iterator:
ts = torch.full((b,), i, device=self.device, dtype=torch.long)
if self.shorten_cond_schedule:
assert self.model.conditioning_key != 'hybrid'
tc = self.cond_ids[ts].to(cond.device)
cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
img, x0_partial = self.p_sample(img, cond, ts,
clip_denoised=self.clip_denoised,
quantize_denoised=quantize_denoised, return_x0=True,
temperature=temperature[i], noise_dropout=noise_dropout,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
if i % log_every_t == 0 or i == timesteps - 1:
intermediates.append(x0_partial)
if callback: callback(i)
if img_callback: img_callback(img, i)
return img, intermediates
    @torch.no_grad()
    def p_sample_loop(self, cond, shape, return_intermediates=False,
                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, start_T=None,
                      log_every_t=None):
        if not log_every_t:
            log_every_t = self.log_every_t
        device = self.betas.device
        b = shape[0]
        # Start from Gaussian noise unless an initial latent x_T is supplied.
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T
        intermediates = [img]
        if timesteps is None:
            timesteps = self.num_timesteps
        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
            range(0, timesteps))
        if mask is not None:
            assert x0 is not None
            assert x0.shape[2:] == mask.shape[2:]  # spatial size (H and W) has to match
        for i in iterator:
            ts = torch.full((b,), i, device=device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                # Diffuse the conditioning to its mapped timestep tc.
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
            img = self.p_sample(img, cond, ts,
                                clip_denoised=self.clip_denoised,
                                quantize_denoised=quantize_denoised)
            if mask is not None:
                # Keep the noised x0 where mask == 1 and the sample elsewhere.
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img
            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(img)
            if callback: callback(i)
            if img_callback: img_callback(img, i)
        if return_intermediates:
            return img, intermediates
        return img
    @torch.no_grad()
    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
               verbose=True, timesteps=None, quantize_denoised=False,
               mask=None, x0=None, shape=None, **kwargs):
        if shape is None:
            shape = (batch_size, self.channels, self.image_size, self.image_size)
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                        list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
        return self.p_sample_loop(cond,
                                  shape,
                                  return_intermediates=return_intermediates, x_T=x_T,
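The masked sampling loop in `p_sample_loop` above combines two ideas: blending a known region back into the sample at every step, and logging intermediates only at selected steps. The following is a minimal pure-Python sketch of that control flow, not the repository's implementation. The names `blend` and `toy_sample_loop`, the halving "denoiser", and the linear noise level `i / timesteps` are all hypothetical stand-ins for `p_sample` and `q_sample`.

```python
import random


def blend(known, generated, mask):
    """Keep `known` where mask is 1 and `generated` where mask is 0."""
    return [k * m + (1.0 - m) * g for k, g, m in zip(known, generated, mask)]


def toy_sample_loop(x0, mask, timesteps, log_every_t, seed=0):
    """Toy stand-in for a masked reverse-diffusion loop (no real model)."""
    rng = random.Random(seed)
    img = [rng.gauss(0.0, 1.0) for _ in x0]  # start from pure noise
    logged_steps = []
    for i in reversed(range(timesteps)):
        # Hypothetical denoiser: shrink the sample halfway towards zero.
        img = [0.5 * v for v in img]
        # Hypothetical forward process: noise level grows with the step index.
        sigma = i / timesteps
        known = [v + sigma * rng.gauss(0.0, 1.0) for v in x0]
        # Same blending rule as `img_orig * mask + (1. - mask) * img`.
        img = blend(known, img, mask)
        # Same logging rule as the loop above: every log_every_t steps,
        # plus the very first iteration (i == timesteps - 1).
        if i % log_every_t == 0 or i == timesteps - 1:
            logged_steps.append(i)
    return img, logged_steps


final, steps = toy_sample_loop(x0=[1.0, 2.0, 3.0, 4.0], mask=[1.0, 1.0, 0.0, 0.0],
                               timesteps=10, log_every_t=4)
print(steps)      # steps at which an intermediate would be stored
print(final[:2])  # masked positions, taken from the known signal
```

In this toy the noise level is zero at the last step, so the masked positions end up equal to `x0`. The real `q_sample` follows the model's noise schedule, so that exact equality should not be assumed there.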
SYMBOL INDEX (577 symbols across 31 files)
FILE: cupy_module/dsepconv.py
class Stream (line 7) | class Stream:
function cupy_kernel (line 451) | def cupy_kernel(strFunction, objectVariables):
function cupy_launch (line 511) | def cupy_launch(strFunction, strKernel):
class _FunctionDSepconv (line 518) | class _FunctionDSepconv(torch.autograd.Function):
method forward (line 520) | def forward(self, input, vertical, horizontal, offset_x, offset_y, mask):
method backward (line 571) | def backward(self, gradOutput):
function FunctionDSepconv (line 701) | def FunctionDSepconv(tensorInput, tensorVertical, tensorHorizontal, tens...
class ModuleDSepconv (line 707) | class ModuleDSepconv(torch.nn.Module):
method __init__ (line 708) | def __init__(self):
method forward (line 713) | def forward(self, tensorInput, tensorVertical, tensorHorizontal, tenso...
FILE: evaluate.py
function main (line 26) | def main():
FILE: evaluate_vqm.py
function main (line 16) | def main():
FILE: interpolate_yuv.py
function main (line 31) | def main():
FILE: ldm/data/bvi_vimeo.py
class Vimeo90k_triplet (line 11) | class Vimeo90k_triplet(Dataset):
method __init__ (line 12) | def __init__(self, db_dir, train=True, crop_sz=(256,256), augment_s=T...
method __getitem__ (line 29) | def __getitem__(self, index):
method __len__ (line 52) | def __len__(self):
class Vimeo90k_quintuplet (line 56) | class Vimeo90k_quintuplet(Dataset):
method __init__ (line 57) | def __init__(self, db_dir, train=True, crop_sz=(256,256), augment_s=T...
method __getitem__ (line 74) | def __getitem__(self, index):
method __len__ (line 94) | def __len__(self):
class BVIDVC_triplet (line 98) | class BVIDVC_triplet(Dataset):
method __init__ (line 99) | def __init__(self, db_dir, res=None, crop_sz=(256,256), augment_s=True...
method __getitem__ (line 107) | def __getitem__(self, index):
method __len__ (line 133) | def __len__(self):
class BVIDVC_quintuplet (line 137) | class BVIDVC_quintuplet(Dataset):
method __init__ (line 138) | def __init__(self, db_dir, res=None, crop_sz=(256,256), augment_s=True...
method __getitem__ (line 146) | def __getitem__(self, index):
method __len__ (line 166) | def __len__(self):
class Sampler (line 170) | class Sampler(Dataset):
method __init__ (line 171) | def __init__(self, datasets, p_datasets=None, iter=False, samples_per_...
method __getitem__ (line 186) | def __getitem__(self, index):
method __len__ (line 199) | def __len__(self):
class BVI_Vimeo_triplet (line 206) | class BVI_Vimeo_triplet(Dataset):
method __init__ (line 207) | def __init__(self, db_dir, crop_sz=[256,256], p_datasets=None, iter=Fa...
method __getitem__ (line 225) | def __getitem__(self, index):
method __len__ (line 238) | def __len__(self):
FILE: ldm/data/testsets.py
class TripletTestSet (line 17) | class TripletTestSet:
method __init__ (line 18) | def __init__(self):
method eval (line 21) | def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_di...
class Middlebury_others (line 80) | class Middlebury_others(TripletTestSet):
method __init__ (line 81) | def __init__(self, db_dir):
class Davis (line 93) | class Davis(TripletTestSet):
method __init__ (line 94) | def __init__(self, db_dir):
class Ucf (line 107) | class Ucf(TripletTestSet):
method __init__ (line 108) | def __init__(self, db_dir):
class Snufilm (line 121) | class Snufilm(TripletTestSet):
method __init__ (line 122) | def __init__(self, db_dir, mode):
class Snufilm_easy (line 139) | class Snufilm_easy(Snufilm):
method __init__ (line 140) | def __init__(self, db_dir):
class Snufilm_medium (line 144) | class Snufilm_medium(Snufilm):
method __init__ (line 145) | def __init__(self, db_dir):
class Snufilm_hard (line 149) | class Snufilm_hard(Snufilm):
method __init__ (line 150) | def __init__(self, db_dir):
class Snufilm_extreme (line 154) | class Snufilm_extreme(Snufilm):
method __init__ (line 155) | def __init__(self, db_dir):
class VFITex_triplet (line 159) | class VFITex_triplet:
method __init__ (line 160) | def __init__(self, db_dir):
method eval (line 166) | def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_di...
class Davis90_triplet (line 253) | class Davis90_triplet:
method __init__ (line 254) | def __init__(self, db_dir):
method eval (line 260) | def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_di...
class Ucf101_triplet (line 341) | class Ucf101_triplet:
method __init__ (line 342) | def __init__(self, db_dir):
method eval (line 356) | def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_di...
FILE: ldm/data/testsets_vqm.py
class TripletTestSet (line 16) | class TripletTestSet:
method __init__ (line 17) | def __init__(self):
method eval (line 20) | def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name='out...
class Middlebury_others (line 65) | class Middlebury_others(TripletTestSet):
method __init__ (line 66) | def __init__(self, db_dir):
class Davis (line 78) | class Davis(TripletTestSet):
method __init__ (line 79) | def __init__(self, db_dir):
class Ucf (line 92) | class Ucf(TripletTestSet):
method __init__ (line 93) | def __init__(self, db_dir):
class Snufilm (line 106) | class Snufilm(TripletTestSet):
method __init__ (line 107) | def __init__(self, db_dir, mode):
class Snufilm_easy (line 124) | class Snufilm_easy(Snufilm):
method __init__ (line 125) | def __init__(self, db_dir):
class Snufilm_medium (line 129) | class Snufilm_medium(Snufilm):
method __init__ (line 130) | def __init__(self, db_dir):
class Snufilm_hard (line 134) | class Snufilm_hard(Snufilm):
method __init__ (line 135) | def __init__(self, db_dir):
class Snufilm_extreme (line 139) | class Snufilm_extreme(Snufilm):
method __init__ (line 140) | def __init__(self, db_dir):
class VFITex_triplet (line 144) | class VFITex_triplet:
method __init__ (line 145) | def __init__(self, db_dir):
method eval (line 151) | def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name=None...
class Davis90_triplet (line 223) | class Davis90_triplet:
method __init__ (line 224) | def __init__(self, db_dir):
method eval (line 230) | def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name=None...
class Ucf101_triplet (line 296) | class Ucf101_triplet:
method __init__ (line 297) | def __init__(self, db_dir):
method eval (line 311) | def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name='out...
FILE: ldm/data/vfitransforms.py
function rand_crop (line 7) | def rand_crop(*args, sz):
function rand_flip (line 15) | def rand_flip(*args, p):
function rand_reverse (line 26) | def rand_reverse(*args, p):
FILE: ldm/lr_scheduler.py
class LambdaWarmUpCosineScheduler (line 4) | class LambdaWarmUpCosineScheduler:
method __init__ (line 8) | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_...
method schedule (line 17) | def schedule(self, n, **kwargs):
method __call__ (line 32) | def __call__(self, n, **kwargs):
class LambdaWarmUpCosineScheduler2 (line 36) | class LambdaWarmUpCosineScheduler2:
method __init__ (line 41) | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths...
method find_in_interval (line 52) | def find_in_interval(self, n):
method schedule (line 59) | def schedule(self, n, **kwargs):
method __call__ (line 77) | def __call__(self, n, **kwargs):
class LambdaLinearScheduler (line 81) | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
method schedule (line 83) | def schedule(self, n, **kwargs):
FILE: ldm/models/autoencoder.py
class VQFlowNet (line 18) | class VQFlowNet(pl.LightningModule):
method __init__ (line 19) | def __init__(self,
method ema_scope (line 74) | def ema_scope(self, context=None):
method init_from_ckpt (line 88) | def init_from_ckpt(self, path, ignore_keys=list()):
method on_train_batch_end (line 102) | def on_train_batch_end(self, *args, **kwargs):
method encode (line 106) | def encode(self, x, ret_feature=False):
method encode_to_prequant (line 145) | def encode_to_prequant(self, x):
method decode (line 150) | def decode(self, quant, x_prev, x_next):
method decode_code (line 167) | def decode_code(self, code_b):
method forward (line 172) | def forward(self, input, x_prev, x_next, return_pred_indices=False):
method get_input (line 180) | def get_input(self, batch, k):
method training_step (line 198) | def training_step(self, batch, batch_idx, optimizer_idx):
method validation_step (line 221) | def validation_step(self, batch, batch_idx):
method _validation_step (line 227) | def _validation_step(self, batch, batch_idx, suffix=""):
method configure_optimizers (line 254) | def configure_optimizers(self):
method get_last_layer (line 287) | def get_last_layer(self):
method log_images (line 290) | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
method to_rgb (line 314) | def to_rgb(self, x):
class VQFlowNetInterface (line 323) | class VQFlowNetInterface(VQFlowNet):
method __init__ (line 324) | def __init__(self, **kwargs):
method encode (line 327) | def encode(self, x, ret_feature=False):
method decode (line 366) | def decode(self, h, x_prev, x_next, phi_prev_list, phi_next_list, forc...
FILE: ldm/models/diffusion/ddim.py
class DDIMSampler (line 11) | class DDIMSampler(object):
method __init__ (line 12) | def __init__(self, model, schedule="linear", **kwargs):
method register_buffer (line 18) | def register_buffer(self, name, attr):
method make_schedule (line 24) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
method sample (line 56) | def sample(self,
method ddim_sampling (line 115) | def ddim_sampling(self, cond, shape,
method p_sample_ddim (line 168) | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_origin...
FILE: ldm/models/diffusion/ddpm.py
function disabled_train (line 33) | def disabled_train(self, mode=True):
class DDPM (line 40) | class DDPM(pl.LightningModule):
method __init__ (line 42) | def __init__(self,
method register_schedule (line 113) | def register_schedule(self, given_betas=None, beta_schedule="linear", ...
method ema_scope (line 168) | def ema_scope(self, context=None):
method init_from_ckpt (line 182) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
method q_mean_variance (line 200) | def q_mean_variance(self, x_start, t):
method predict_start_from_noise (line 212) | def predict_start_from_noise(self, x_t, t, noise):
method q_posterior (line 218) | def q_posterior(self, x_start, x_t, t):
method p_mean_variance (line 227) | def p_mean_variance(self, x, t, clip_denoised: bool):
method p_sample (line 240) | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
method p_sample_loop (line 249) | def p_sample_loop(self, shape, return_intermediates=False):
method sample (line 264) | def sample(self, batch_size=16, return_intermediates=False):
method q_sample (line 270) | def q_sample(self, x_start, t, noise=None):
method get_loss (line 275) | def get_loss(self, pred, target, mean=True):
method p_losses (line 290) | def p_losses(self, x_start, t, noise=None):
method forward (line 319) | def forward(self, x, *args, **kwargs):
method get_input (line 325) | def get_input(self, batch, k):
method shared_step (line 333) | def shared_step(self, batch):
method training_step (line 338) | def training_step(self, batch, batch_idx):
method validation_step (line 354) | def validation_step(self, batch, batch_idx):
method on_train_batch_end (line 362) | def on_train_batch_end(self, *args, **kwargs):
method _get_rows_from_list (line 366) | def _get_rows_from_list(self, samples):
method log_images (line 374) | def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=Non...
method configure_optimizers (line 411) | def configure_optimizers(self):
class LatentDiffusion (line 420) | class LatentDiffusion(DDPM):
method __init__ (line 422) | def __init__(self,
method make_cond_schedule (line 462) | def make_cond_schedule(self, ):
method register_schedule (line 468) | def register_schedule(self,
method instantiate_first_stage (line 477) | def instantiate_first_stage(self, config):
method instantiate_cond_stage (line 484) | def instantiate_cond_stage(self, config):
method _get_denoise_row_from_list (line 505) | def _get_denoise_row_from_list(self, samples, desc='', force_no_decode...
method get_first_stage_encoding (line 517) | def get_first_stage_encoding(self, encoder_posterior):
method get_learned_conditioning (line 524) | def get_learned_conditioning(self, c):
method get_input (line 537) | def get_input(self, batch, k, return_first_stage_outputs=False, force_...
method decode_first_stage (line 589) | def decode_first_stage(self, z, predict_cids=False, force_not_quantize...
method encode_first_stage (line 602) | def encode_first_stage(self, x):
method shared_step (line 605) | def shared_step(self, batch, **kwargs):
method forward (line 610) | def forward(self, x, c, *args, **kwargs):
method apply_model (line 622) | def apply_model(self, x_noisy, t, cond, return_ids=False):
method p_losses (line 641) | def p_losses(self, x_start, cond, t, noise=None):
method p_mean_variance (line 676) | def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codeboo...
method p_sample (line 708) | def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
method progressive_denoising (line 739) | def progressive_denoising(self, cond, shape, verbose=True, callback=No...
method p_sample_loop (line 795) | def p_sample_loop(self, cond, shape, return_intermediates=False,
method sample (line 846) | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=...
method sample_log (line 864) | def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
method log_images (line 880) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
method configure_optimizers (line 989) | def configure_optimizers(self):
method to_rgb (line 1014) | def to_rgb(self, x):
class DiffusionWrapper (line 1023) | class DiffusionWrapper(pl.LightningModule):
method __init__ (line 1024) | def __init__(self, diff_model_config, conditioning_key):
method forward (line 1030) | def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
class LatentDiffusionVFI (line 1063) | class LatentDiffusionVFI(DDPM):
method __init__ (line 1065) | def __init__(self,
method make_cond_schedule (line 1105) | def make_cond_schedule(self, ):
method register_schedule (line 1111) | def register_schedule(self,
method instantiate_first_stage (line 1120) | def instantiate_first_stage(self, config):
method instantiate_cond_stage (line 1127) | def instantiate_cond_stage(self, config):
method _get_denoise_row_from_list (line 1148) | def _get_denoise_row_from_list(self, samples, xc=None, phi_prev_list=N...
method get_first_stage_encoding (line 1161) | def get_first_stage_encoding(self, encoder_posterior):
method get_learned_conditioning (line 1168) | def get_learned_conditioning(self, c):
method get_input (line 1180) | def get_input(self, batch, k, return_first_stage_outputs=False, force_...
method decode_first_stage (line 1228) | def decode_first_stage(self, z, xc=None, phi_prev_list=None, phi_next_...
method encode_first_stage (line 1241) | def encode_first_stage(self, x):
method shared_step (line 1244) | def shared_step(self, batch, **kwargs):
method forward (line 1249) | def forward(self, x, c, *args, **kwargs):
method apply_model (line 1261) | def apply_model(self, x_noisy, t, cond, return_ids=False):
method p_losses (line 1280) | def p_losses(self, x_start, cond, t, noise=None):
method p_mean_variance (line 1315) | def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codeboo...
method p_sample (line 1347) | def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
method progressive_denoising (line 1378) | def progressive_denoising(self, cond, shape, verbose=True, callback=No...
method p_sample_loop (line 1434) | def p_sample_loop(self, cond, shape, return_intermediates=False,
method sample (line 1485) | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=...
method sample_ddpm (line 1502) | def sample_ddpm(self, conditioning, batch_size=16, return_intermediate...
method sample_log (line 1522) | def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
method log_images (line 1538) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
method configure_optimizers (line 1620) | def configure_optimizers(self):
method to_rgb (line 1645) | def to_rgb(self, x):
FILE: ldm/modules/attention.py
function exists (line 11) | def exists(val):
function uniq (line 15) | def uniq(arr):
function default (line 19) | def default(val, d):
function max_neg_value (line 25) | def max_neg_value(t):
function init_ (line 29) | def init_(tensor):
class GEGLU (line 37) | class GEGLU(nn.Module):
method __init__ (line 38) | def __init__(self, dim_in, dim_out):
method forward (line 42) | def forward(self, x):
class FeedForward (line 47) | class FeedForward(nn.Module):
method __init__ (line 48) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
method forward (line 63) | def forward(self, x):
function zero_module (line 67) | def zero_module(module):
function Normalize (line 76) | def Normalize(in_channels):
class LinearAttention (line 80) | class LinearAttention(nn.Module):
method __init__ (line 81) | def __init__(self, dim, heads=4, dim_head=32):
method forward (line 88) | def forward(self, x):
class SpatialSelfAttention (line 99) | class SpatialSelfAttention(nn.Module):
method __init__ (line 100) | def __init__(self, in_channels):
method forward (line 126) | def forward(self, x):
class CrossAttention (line 152) | class CrossAttention(nn.Module):
method __init__ (line 157) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
method forward (line 174) | def forward(self, x, context=None, mask=None):
class SpatialCrossAttention (line 200) | class SpatialCrossAttention(nn.Module):
method __init__ (line 207) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
method forward (line 226) | def forward(self, x, context=None):
function posemb_sincos_2d (line 259) | def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
class SpatialCrossAttentionWithPosEmb (line 276) | class SpatialCrossAttentionWithPosEmb(nn.Module):
method __init__ (line 283) | def __init__(self, in_channels=None, heads=8, dim_head=64, dropout=0.):
method forward (line 312) | def forward(self, x, context=None):
class BasicTransformerBlock (line 361) | class BasicTransformerBlock(nn.Module):
method __init__ (line 369) | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None,...
method forward (line 380) | def forward(self, x, context=None):
method _forward (line 383) | def _forward(self, x, context=None):
class SpatialTransformer (line 390) | class SpatialTransformer(nn.Module):
method __init__ (line 399) | def __init__(self, in_channels, n_heads, d_head,
method forward (line 423) | def forward(self, x, context=None):
FILE: ldm/modules/diffusionmodules/model.py
function get_timestep_embedding (line 13) | def get_timestep_embedding(timesteps, embedding_dim):
function nonlinearity (line 34) | def nonlinearity(x):
function Normalize (line 39) | def Normalize(in_channels, num_groups=32):
class IdentityWrapper (line 43) | class IdentityWrapper(nn.Module):
method __init__ (line 47) | def __init__(self) -> None:
method forward (line 51) | def forward(self, x, context=None):
class Upsample (line 56) | class Upsample(nn.Module):
method __init__ (line 57) | def __init__(self, in_channels, with_conv):
method forward (line 67) | def forward(self, x):
class Downsample (line 74) | class Downsample(nn.Module):
method __init__ (line 75) | def __init__(self, in_channels, with_conv):
method forward (line 86) | def forward(self, x):
class ResnetBlock (line 96) | class ResnetBlock(nn.Module):
method __init__ (line 97) | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=Fa...
method forward (line 135) | def forward(self, x, temb):
class LinAttnBlock (line 158) | class LinAttnBlock(LinearAttention):
method __init__ (line 160) | def __init__(self, in_channels):
class AttnBlock (line 164) | class AttnBlock(nn.Module):
method __init__ (line 165) | def __init__(self, in_channels):
method forward (line 192) | def forward(self, x):
function make_attn (line 219) | def make_attn(in_channels, attn_type="vanilla"):
class FIEncoder (line 234) | class FIEncoder(nn.Module):
method __init__ (line 235) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 300) | def forward(self, x, ret_feature=False):
class FlowEncoder (line 333) | class FlowEncoder(FIEncoder):
method __init__ (line 334) | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks...
class FlowDecoderWithResidual (line 354) | class FlowDecoderWithResidual(nn.Module):
method __init__ (line 355) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
method forward (line 520) | def forward(self, z, cond_dict):
FILE: ldm/modules/diffusionmodules/openaimodel.py
function convert_module_to_f16 (line 26) | def convert_module_to_f16(x):
function convert_module_to_f32 (line 29) | def convert_module_to_f32(x):
class AttentionPool2d (line 34) | class AttentionPool2d(nn.Module):
method __init__ (line 39) | def __init__(
method forward (line 53) | def forward(self, x):
class TimestepBlock (line 64) | class TimestepBlock(nn.Module):
method forward (line 70) | def forward(self, x, emb):
class TimestepEmbedSequential (line 76) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
method forward (line 82) | def forward(self, x, emb, context=None):
class Upsample (line 93) | class Upsample(nn.Module):
method __init__ (line 102) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
method forward (line 111) | def forward(self, x):
class TransposedUpsample (line 123) | class TransposedUpsample(nn.Module):
method __init__ (line 125) | def __init__(self, channels, out_channels=None, ks=5):
method forward (line 132) | def forward(self,x):
class Downsample (line 136) | class Downsample(nn.Module):
method __init__ (line 145) | def __init__(self, channels, use_conv, dims=2, out_channels=None,paddi...
method forward (line 160) | def forward(self, x):
class ResBlock (line 165) | class ResBlock(TimestepBlock):
method __init__ (line 181) | def __init__(
method forward (line 245) | def forward(self, x, emb):
method _forward (line 257) | def _forward(self, x, emb):
class AttentionBlock (line 280) | class AttentionBlock(nn.Module):
method __init__ (line 287) | def __init__(
method forward (line 316) | def forward(self, x):
method _forward (line 320) | def _forward(self, x):
function count_flops_attn (line 329) | def count_flops_attn(model, _x, y):
class QKVAttentionLegacy (line 349) | class QKVAttentionLegacy(nn.Module):
method __init__ (line 354) | def __init__(self, n_heads):
method forward (line 358) | def forward(self, qkv):
method count_flops (line 377) | def count_flops(model, _x, y):
class QKVAttention (line 381) | class QKVAttention(nn.Module):
method __init__ (line 386) | def __init__(self, n_heads):
method forward (line 390) | def forward(self, qkv):
method count_flops (line 411) | def count_flops(model, _x, y):
class UNetModel (line 415) | class UNetModel(nn.Module):
method __init__ (line 445) | def __init__(
method convert_to_fp16 (line 711) | def convert_to_fp16(self):
method convert_to_fp32 (line 719) | def convert_to_fp32(self):
method forward (line 727) | def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
class EncoderUNetModel (line 762) | class EncoderUNetModel(nn.Module):
method __init__ (line 768) | def __init__(
method convert_to_fp16 (line 941) | def convert_to_fp16(self):
method convert_to_fp32 (line 948) | def convert_to_fp32(self):
method forward (line 955) | def forward(self, x, timesteps):
FILE: ldm/modules/diffusionmodules/util.py
function make_beta_schedule (line 21) | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_e...
function make_ddim_timesteps (line 46) | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_...
function make_ddim_sampling_parameters (line 63) | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbos...
function betas_for_alpha_bar (line 77) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
function extract_into_tensor (line 96) | def extract_into_tensor(a, t, x_shape):
function checkpoint (line 102) | def checkpoint(func, inputs, params, flag):
class CheckpointFunction (line 119) | class CheckpointFunction(torch.autograd.Function):
method forward (line 121) | def forward(ctx, run_function, length, *args):
method backward (line 131) | def backward(ctx, *output_grads):
function timestep_embedding (line 151) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
function zero_module (line 174) | def zero_module(module):
function scale_module (line 183) | def scale_module(module, scale):
function mean_flat (line 192) | def mean_flat(tensor):
function normalization (line 199) | def normalization(channels):
class SiLU (line 209) | class SiLU(nn.Module):
method forward (line 210) | def forward(self, x):
class GroupNorm32 (line 214) | class GroupNorm32(nn.GroupNorm):
method forward (line 215) | def forward(self, x):
function conv_nd (line 218) | def conv_nd(dims, *args, **kwargs):
function linear (line 231) | def linear(*args, **kwargs):
function avg_pool_nd (line 238) | def avg_pool_nd(dims, *args, **kwargs):
class HybridConditioner (line 251) | class HybridConditioner(nn.Module):
method __init__ (line 253) | def __init__(self, c_concat_config, c_crossattn_config):
method forward (line 258) | def forward(self, c_concat, c_crossattn):
function noise_like (line 264) | def noise_like(shape, device, repeat=False):
FILE: ldm/modules/ema.py
class LitEma (line 5) | class LitEma(nn.Module):
method __init__ (line 6) | def __init__(self, model, decay=0.9999, use_num_upates=True):
method forward (line 25) | def forward(self,model):
method copy_to (line 46) | def copy_to(self, model):
method store (line 55) | def store(self, parameters):
method restore (line 64) | def restore(self, parameters):
FILE: ldm/modules/losses/vqperceptual.py
function hinge_d_loss_with_exemplar_weights (line 11) | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
function adopt_weight (line 20) | def adopt_weight(weight, global_step, threshold=0, value=0.):
function measure_perplexity (line 26) | def measure_perplexity(predicted_indices, n_embed):
function l1 (line 35) | def l1(x, y):
function l2 (line 39) | def l2(x, y):
class VQLPIPSWithDiscriminator (line 43) | class VQLPIPSWithDiscriminator(nn.Module):
method __init__ (line 44) | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
method calculate_adaptive_weight (line 85) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
method forward (line 98) | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
FILE: ldm/modules/maxvit.py
function exists (line 12) | def exists(val):
function default (line 16) | def default(val, d):
class PreNormResidual (line 22) | class PreNormResidual(nn.Module):
method __init__ (line 23) | def __init__(self, dim, fn):
method forward (line 28) | def forward(self, x, c=None):
class SqueezeExcitation (line 34) | class SqueezeExcitation(nn.Module):
method __init__ (line 35) | def __init__(self, dim, shrinkage_rate = 0.25):
method forward (line 48) | def forward(self, x):
class FeedForward (line 52) | class FeedForward(nn.Module):
method __init__ (line 53) | def __init__(self, dim, mult = 4, dropout = 0.):
method forward (line 63) | def forward(self, x):
class Attention (line 67) | class Attention(nn.Module):
method __init__ (line 68) | def __init__(
method forward (line 108) | def forward(self, x, c=None):
class Dropsample (line 158) | class Dropsample(nn.Module):
method __init__ (line 159) | def __init__(self, prob = 0):
method forward (line 163) | def forward(self, x):
class MBConvResidual (line 173) | class MBConvResidual(nn.Module):
method __init__ (line 174) | def __init__(self, fn, dropout = 0.):
method forward (line 179) | def forward(self, x):
function MBConv (line 185) | def MBConv(
class MaxAttentionBlock (line 215) | class MaxAttentionBlock(nn.Module):
method __init__ (line 216) | def __init__(self, in_channels, heads=8, dim_head=64, dropout=0., wind...
method forward (line 232) | def forward(self, x):
class SpatialCrossAttentionWithMax (line 249) | class SpatialCrossAttentionWithMax(nn.Module):
method __init__ (line 250) | def __init__(self, in_channels, heads=8, dim_head=64, ctx_dim=None, dr...
method forward (line 274) | def forward(self, x, context=None):
class SpatialTransformerWithMax (line 298) | class SpatialTransformerWithMax(nn.Module):
method __init__ (line 307) | def __init__(self, in_channels, n_heads, d_head, dropout=0., context_d...
method forward (line 325) | def forward(self, x, context=None):
FILE: ldm/util.py
function log_txt_as_img (line 17) | def log_txt_as_img(wh, xc, size=10):
function ismap (line 41) | def ismap(x):
function isimage (line 47) | def isimage(x):
function exists (line 53) | def exists(x):
function default (line 57) | def default(val, d):
function mean_flat (line 63) | def mean_flat(tensor):
function count_params (line 71) | def count_params(model, verbose=False):
function instantiate_from_config (line 78) | def instantiate_from_config(config):
function get_obj_from_str (line 88) | def get_obj_from_str(string, reload=False):
function _do_parallel_data_prefetch (line 96) | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
function parallel_data_prefetch (line 108) | def parallel_data_prefetch(
FILE: main.py
function get_parser (line 23) | def get_parser(**parser_kwargs):
function nondefault_trainer_args (line 125) | def nondefault_trainer_args(opt):
class WrappedDataset (line 132) | class WrappedDataset(Dataset):
method __init__ (line 135) | def __init__(self, dataset):
method __len__ (line 138) | def __len__(self):
method __getitem__ (line 141) | def __getitem__(self, idx):
function worker_init_fn (line 145) | def worker_init_fn(_):
class DataModuleFromConfig (line 154) | class DataModuleFromConfig(pl.LightningDataModule):
method __init__ (line 155) | def __init__(self, batch_size, train=None, validation=None, test=None,...
method prepare_data (line 177) | def prepare_data(self):
method setup (line 181) | def setup(self, stage=None):
method _train_dataloader (line 189) | def _train_dataloader(self):
method _val_dataloader (line 198) | def _val_dataloader(self, shuffle=False):
method _test_dataloader (line 209) | def _test_dataloader(self, shuffle=False):
method _predict_dataloader (line 218) | def _predict_dataloader(self, shuffle=False):
class SetupCallback (line 227) | class SetupCallback(Callback):
method __init__ (line 228) | def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, light...
method on_keyboard_interrupt (line 238) | def on_keyboard_interrupt(self, trainer, pl_module):
method on_fit_start (line 244) | def on_fit_start(self, trainer, pl_module):
class ImageLogger (line 276) | class ImageLogger(Callback):
method __init__ (line 277) | def __init__(self, batch_frequency, val_batch_frequency, max_images, c...
method _testtube (line 299) | def _testtube(self, pl_module, images, batch_idx, split):
method log_local (line 310) | def log_local(self, save_dir, split, images,
method log_img (line 329) | def log_img(self, pl_module, batch, batch_idx, split="train"):
method check_frequency (line 382) | def check_frequency(self, check_idx):
method on_train_batch_end (line 393) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
method on_validation_batch_end (line 397) | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, ...
method on_validation_epoch_end (line 404) | def on_validation_epoch_end(self, trainer, pl_module):
class CUDACallback (line 410) | class CUDACallback(Callback):
method on_train_epoch_start (line 412) | def on_train_epoch_start(self, trainer, pl_module):
method on_train_epoch_end (line 418) | def on_train_epoch_end(self, trainer, pl_module, outputs=None):
function melk (line 681) | def melk(*args, **kwargs):
function divein (line 689) | def divein(*args, **kwargs):
FILE: metrics/flolpips/correlation/correlation.py
function cupy_kernel (line 235) | def cupy_kernel(strFunction, objVariables):
function cupy_launch (line 274) | def cupy_launch(strFunction, strKernel):
class _FunctionCorrelation (line 278) | class _FunctionCorrelation(torch.autograd.Function):
method forward (line 280) | def forward(self, first, second):
method backward (line 333) | def backward(self, gradOutput):
function FunctionCorrelation (line 385) | def FunctionCorrelation(tenFirst, tenSecond):
class ModuleCorrelation (line 389) | class ModuleCorrelation(torch.nn.Module):
method __init__ (line 390) | def __init__(self):
method forward (line 394) | def forward(self, tenFirst, tenSecond):
FILE: metrics/flolpips/flolpips.py
function spatial_average (line 17) | def spatial_average(in_tens, keepdim=True):
function mw_spatial_average (line 20) | def mw_spatial_average(in_tens, flow, keepdim=True):
function mtw_spatial_average (line 28) | def mtw_spatial_average(in_tens, flow, texture, keepdim=True):
function m2w_spatial_average (line 41) | def m2w_spatial_average(in_tens, flow, keepdim=True):
function upsample (line 48) | def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same fo...
class LPIPS (line 53) | class LPIPS(nn.Module):
method __init__ (line 54) | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=T...
method forward (line 111) | def forward(self, in0, in1, retPerLayer=False, normalize=False):
class ScalingLayer (line 157) | class ScalingLayer(nn.Module):
method __init__ (line 158) | def __init__(self):
method forward (line 163) | def forward(self, inp):
class NetLinLayer (line 167) | class NetLinLayer(nn.Module):
method __init__ (line 169) | def __init__(self, chn_in, chn_out=1, use_dropout=False):
method forward (line 176) | def forward(self, x):
class Dist2LogitLayer (line 179) | class Dist2LogitLayer(nn.Module):
method __init__ (line 181) | def __init__(self, chn_mid=32, use_sigmoid=True):
method forward (line 193) | def forward(self,d0,d1,eps=0.1):
class BCERankingLoss (line 196) | class BCERankingLoss(nn.Module):
method __init__ (line 197) | def __init__(self, chn_mid=32):
method forward (line 203) | def forward(self, d0, d1, judge):
class FakeNet (line 209) | class FakeNet(nn.Module):
method __init__ (line 210) | def __init__(self, use_gpu=True, colorspace='Lab'):
class L2 (line 215) | class L2(FakeNet):
method forward (line 216) | def forward(self, in0, in1, retPerLayer=None):
class DSSIM (line 231) | class DSSIM(FakeNet):
method forward (line 233) | def forward(self, in0, in1, retPerLayer=None):
function print_network (line 246) | def print_network(net):
class FloLPIPS (line 254) | class FloLPIPS(LPIPS):
method __init__ (line 255) | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=T...
method forward (line 258) | def forward(self, in0, in1, flow, retPerLayer=False, normalize=False):
function calc_flolpips (line 278) | def calc_flolpips(dis_path, ref_path):
FILE: metrics/flolpips/pretrained_networks.py
class squeezenet (line 5) | class squeezenet(torch.nn.Module):
method __init__ (line 6) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 35) | def forward(self, X):
class alexnet (line 56) | class alexnet(torch.nn.Module):
method __init__ (line 57) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 80) | def forward(self, X):
class vgg16 (line 96) | class vgg16(torch.nn.Module):
method __init__ (line 97) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 120) | def forward(self, X):
class resnet (line 138) | class resnet(torch.nn.Module):
method __init__ (line 139) | def __init__(self, requires_grad=False, pretrained=True, num=18):
method forward (line 162) | def forward(self, X):
FILE: metrics/flolpips/pwcnet.py
function backwarp (line 45) | def backwarp(tenInput, tenFlow):
class Network (line 71) | class Network(torch.nn.Module):
method __init__ (line 72) | def __init__(self):
method forward (line 263) | def forward(self, tenFirst, tenSecond, *args):
method extract_pyramid_single (line 295) | def extract_pyramid_single(self, tenFirst):
function estimate (line 310) | def estimate(tenFirst, tenSecond):
FILE: metrics/flolpips/utils.py
function normalize_tensor (line 6) | def normalize_tensor(in_feat,eps=1e-10):
function l2 (line 10) | def l2(p0, p1, range=255.):
function dssim (line 13) | def dssim(p0, p1, range=255.):
function tensor2im (line 17) | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
function tensor2np (line 22) | def tensor2np(tensor_obj):
function np2tensor (line 26) | def np2tensor(np_obj):
function tensor2tensorlab (line 30) | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
function read_frame_yuv2rgb (line 44) | def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt...
FILE: metrics/lpips/__init__.py
function normalize_tensor (line 42) | def normalize_tensor(in_feat,eps=1e-10):
function l2 (line 46) | def l2(p0, p1, range=255.):
function psnr (line 49) | def psnr(p0, p1, peak=255.):
function dssim (line 52) | def dssim(p0, p1, range=255.):
function rgb2lab (line 56) | def rgb2lab(in_img,mean_cent=False):
function tensor2np (line 63) | def tensor2np(tensor_obj):
function np2tensor (line 67) | def np2tensor(np_obj):
function tensor2tensorlab (line 71) | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
function tensorlab2tensor (line 85) | def tensorlab2tensor(lab_tensor,return_inbnd=False):
function load_image (line 103) | def load_image(path):
function rgb2lab (line 116) | def rgb2lab(input):
function tensor2im (line 120) | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
function im2tensor (line 125) | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
function tensor2vec (line 129) | def tensor2vec(vector_tensor):
function tensor2im (line 133) | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
function im2tensor (line 139) | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
function voc_ap (line 146) | def voc_ap(rec, prec, use_07_metric=False):
FILE: metrics/lpips/lpips.py
function spatial_average (line 14) | def spatial_average(in_tens, keepdim=True):
function upsample (line 22) | def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same fo...
class LPIPS (line 27) | class LPIPS(nn.Module):
method __init__ (line 28) | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=T...
method forward (line 85) | def forward(self, in0, in1, retPerLayer=False, normalize=False):
class ScalingLayer (line 132) | class ScalingLayer(nn.Module):
method __init__ (line 133) | def __init__(self):
method forward (line 138) | def forward(self, inp):
class NetLinLayer (line 142) | class NetLinLayer(nn.Module):
method __init__ (line 144) | def __init__(self, chn_in, chn_out=1, use_dropout=False):
method forward (line 151) | def forward(self, x):
class Dist2LogitLayer (line 154) | class Dist2LogitLayer(nn.Module):
method __init__ (line 156) | def __init__(self, chn_mid=32, use_sigmoid=True):
method forward (line 168) | def forward(self,d0,d1,eps=0.1):
class BCERankingLoss (line 171) | class BCERankingLoss(nn.Module):
method __init__ (line 172) | def __init__(self, chn_mid=32):
method forward (line 178) | def forward(self, d0, d1, judge):
class FakeNet (line 184) | class FakeNet(nn.Module):
method __init__ (line 185) | def __init__(self, use_gpu=True, colorspace='Lab'):
class L2 (line 190) | class L2(FakeNet):
method forward (line 191) | def forward(self, in0, in1, retPerLayer=None):
class DSSIM (line 206) | class DSSIM(FakeNet):
method forward (line 208) | def forward(self, in0, in1, retPerLayer=None):
function print_network (line 221) | def print_network(net):
FILE: metrics/lpips/pretrained_networks.py
class squeezenet (line 5) | class squeezenet(torch.nn.Module):
method __init__ (line 6) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 35) | def forward(self, X):
class alexnet (line 56) | class alexnet(torch.nn.Module):
method __init__ (line 57) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 80) | def forward(self, X):
class vgg16 (line 96) | class vgg16(torch.nn.Module):
method __init__ (line 97) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 120) | def forward(self, X):
class resnet (line 138) | class resnet(torch.nn.Module):
method __init__ (line 139) | def __init__(self, requires_grad=False, pretrained=True, num=18):
method forward (line 162) | def forward(self, X):
FILE: metrics/pytorch_ssim/__init__.py
function gaussian (line 7) | def gaussian(window_size, sigma):
function create_window (line 11) | def create_window(window_size, channel):
function create_window_3d (line 18) | def create_window_3d(window_size, channel=1):
function _ssim (line 26) | def _ssim(img1, img2, window, window_size, channel, size_average = True):
function ssim_matlab (line 49) | def ssim_matlab(img1, img2, window_size=11, window=None, size_average=Tr...
class SSIM (line 105) | class SSIM(torch.nn.Module):
method __init__ (line 106) | def __init__(self, window_size = 11, size_average = True):
method forward (line 113) | def forward(self, img1, img2):
function ssim (line 131) | def ssim(img1, img2, window_size = 11, size_average = True):
FILE: utility.py
function read_frame_yuv2rgb (line 9) | def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt...
function CharbonnierFunc (line 65) | def CharbonnierFunc(data, epsilon=0.001):
function moduleNormalize (line 69) | def moduleNormalize(frame):
function gaussian_kernel (line 73) | def gaussian_kernel(sz, sigma):
function quantize (line 81) | def quantize(imTensor):
function tensor2rgb (line 85) | def tensor2rgb(tensor):
function calc_psnr (line 95) | def calc_psnr(gt, out, *args):
function calc_ssim (line 104) | def calc_ssim(gt, out, *args):
function calc_lpips (line 108) | def calc_lpips(gt, out, *args):
function calc_flolpips (line 114) | def calc_flolpips(gt_list, out_list, inputs_list):
Condensed preview: 47 files, each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 101,
"preview": "*.pth\n*.ckpt\n*__pycache__*\n*.pyc\n*egg*\n*src/*\n*.ipynb\nlogs/*\n*delete*\neval_results*\n*.idea*\n*.pytorch"
},
{
"path": "LICENSE",
"chars": 1068,
"preview": "MIT License\n\nCopyright (c) 2023 danielism97\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
},
{
"path": "README.md",
"chars": 5798,
"preview": "# LDMVFI: Video Frame Interpolation with Latent Diffusion Models\n\n[**Duolikun Danier**](https://danier97.github.io/), [*"
},
{
"path": "configs/autoencoder/vqflow-f32.yaml",
"chars": 1472,
"preview": "model:\n base_learning_rate: 1.0e-5\n target: ldm.models.autoencoder.VQFlowNet\n params:\n monitor: val/total_loss\n "
},
{
"path": "configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml",
"chars": 2659,
"preview": "model:\n base_learning_rate: 1.0e-06\n target: ldm.models.diffusion.ddpm.LatentDiffusionVFI\n params:\n linear_start: "
},
{
"path": "cupy_module/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cupy_module/dsepconv.py",
"chars": 32355,
"preview": "import torch\n\nimport cupy\nimport re\n\n\nclass Stream:\n ptr = torch.cuda.current_stream().cuda_stream\n\n\n# end\n\nkernel_DS"
},
{
"path": "environment.yaml",
"chars": 1216,
"preview": "name: ldmvfi\nchannels:\n - pytorch\n - defaults\n - conda-forge\ndependencies:\n - python=3.9.13\n - pytorch=1.11.0\n - t"
},
{
"path": "evaluate.py",
"chars": 2387,
"preview": "import argparse\nimport os\nimport torch\nfrom functools import partial\nfrom omegaconf import OmegaConf\nfrom main import in"
},
{
"path": "evaluate_vqm.py",
"chars": 1386,
"preview": "import argparse\nimport os\nfrom ldm.data import testsets_vqm\n\n\nparser = argparse.ArgumentParser(description='Frame Interp"
},
{
"path": "interpolate_yuv.py",
"chars": 5038,
"preview": "import argparse\nimport torch\nimport torchvision.transforms.functional as TF\nimport os\nfrom PIL import Image\nfrom tqdm im"
},
{
"path": "ldm/data/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ldm/data/bvi_vimeo.py",
"chars": 9439,
"preview": "import numpy as np\nimport random\nfrom os import listdir\nfrom os.path import join, isdir, split, getsize\nfrom torch.utils"
},
{
"path": "ldm/data/testsets.py",
"chars": 20186,
"preview": "import glob\nfrom typing import List\nfrom PIL import Image\nimport torch\nfrom torchvision import transforms\nimport torchvi"
},
{
"path": "ldm/data/testsets_vqm.py",
"chars": 16418,
"preview": "import glob\nfrom typing import List\nfrom PIL import Image\nimport torch\nfrom torchvision import transforms\nimport torchvi"
},
{
"path": "ldm/data/vfitransforms.py",
"chars": 674,
"preview": "import random\nimport torch\nimport torchvision\nimport torchvision.transforms.functional as TF\n\n\ndef rand_crop(*args, sz):"
},
{
"path": "ldm/lr_scheduler.py",
"chars": 3882,
"preview": "import numpy as np\n\n\nclass LambdaWarmUpCosineScheduler:\n \"\"\"\n note: use with a base_lr of 1.0\n \"\"\"\n def __in"
},
{
"path": "ldm/models/autoencoder.py",
"chars": 15340,
"preview": "import torch\nimport pytorch_lightning as pl\nimport torch.nn.functional as F\nfrom torch.optim.lr_scheduler import LambdaL"
},
{
"path": "ldm/models/diffusion/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ldm/models/diffusion/ddim.py",
"chars": 11063,
"preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom functools import partial\n\nfrom ldm.modu"
},
{
"path": "ldm/models/diffusion/ddpm.py",
"chars": 74715,
"preview": "\"\"\"\nwild mixture of\nhttps://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e316"
},
{
"path": "ldm/modules/attention.py",
"chars": 14950,
"preview": "from inspect import isfunction\nimport math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum\nfro"
},
{
"path": "ldm/modules/diffusionmodules/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ldm/modules/diffusionmodules/model.py",
"chars": 23522,
"preview": "# pytorch_diffusion + derived encoder decoder\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\n\nfrom ld"
},
{
"path": "ldm/modules/diffusionmodules/openaimodel.py",
"chars": 36432,
"preview": "from abc import abstractmethod\nfrom functools import partial\nimport math\nfrom typing import Iterable\n\nimport numpy as np"
},
{
"path": "ldm/modules/diffusionmodules/util.py",
"chars": 9561,
"preview": "# adopted from\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\n# and\n#"
},
{
"path": "ldm/modules/ema.py",
"chars": 2982,
"preview": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n def __init__(self, model, decay=0.9999, use_num_upates="
},
{
"path": "ldm/modules/losses/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "ldm/modules/losses/vqperceptual.py",
"chars": 7838,
"preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom einops import repeat\n\nfrom taming.modules.discrim"
},
{
"path": "ldm/modules/maxvit.py",
"chars": 11799,
"preview": "import torch\nfrom torch import nn, einsum\nimport torch.nn.functional\nfrom einops import rearrange, repeat\nfrom einops.la"
},
{
"path": "ldm/util.py",
"chars": 5857,
"preview": "import importlib\n\nimport torch\nimport numpy as np\nfrom collections import abc\nfrom einops import rearrange\nfrom functool"
},
{
"path": "main.py",
"chars": 27803,
"preview": "import argparse, os, sys, datetime, glob\nimport numpy as np\nimport time\nimport torch\nimport torchvision\nimport pytorch_l"
},
{
"path": "metrics/flolpips/.gitignore",
"chars": 30,
"preview": "*__pycache__*\n*.ipynb\n*delete*"
},
{
"path": "metrics/flolpips/LICENSE",
"chars": 1068,
"preview": "MIT License\n\nCopyright (c) 2022 danielism97\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
},
{
"path": "metrics/flolpips/README.md",
"chars": 1145,
"preview": "# FloLPIPS: A bespoke video quality metric for frame interpoation\n\n### Duolikun Danier, Fan Zhang, David Bull\n\n\n[Project"
},
{
"path": "metrics/flolpips/__init__.py",
"chars": 24,
"preview": "from .flolpips import *\n"
},
{
"path": "metrics/flolpips/correlation/correlation.py",
"chars": 13604,
"preview": "#!/usr/bin/env python\n\nimport torch\n\nimport cupy\nimport re\n\nkernel_Correlation_rearrange = '''\n\textern \"C\" __global__ vo"
},
{
"path": "metrics/flolpips/flolpips.py",
"chars": 15089,
"preview": "\nfrom __future__ import absolute_import\nimport os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.autog"
},
{
"path": "metrics/flolpips/pretrained_networks.py",
"chars": 6507,
"preview": "from collections import namedtuple\nimport torch\nfrom torchvision import models as tv\n\nclass squeezenet(torch.nn.Module):"
},
{
"path": "metrics/flolpips/pwcnet.py",
"chars": 16053,
"preview": "#!/usr/bin/env python\n\nimport torch\n\nimport getopt\nimport math\nimport numpy\nimport os\nimport PIL\nimport PIL.Image\nimport"
},
{
"path": "metrics/flolpips/utils.py",
"chars": 3495,
"preview": "import numpy as np\nimport cv2\nimport torch\n\n\ndef normalize_tensor(in_feat,eps=1e-10):\n norm_factor = torch.sqrt(torch"
},
{
"path": "metrics/lpips/__init__.py",
"chars": 6170,
"preview": "\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport nu"
},
{
"path": "metrics/lpips/lpips.py",
"chars": 9390,
"preview": "\nfrom __future__ import absolute_import\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom torch.auto"
},
{
"path": "metrics/lpips/pretrained_networks.py",
"chars": 6507,
"preview": "from collections import namedtuple\nimport torch\nfrom torchvision import models as tv\n\nclass squeezenet(torch.nn.Module):"
},
{
"path": "metrics/pytorch_ssim/__init__.py",
"chars": 4844,
"preview": "import torch\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nimport numpy as np\nfrom math import exp"
},
{
"path": "setup.py",
"chars": 233,
"preview": "from setuptools import setup, find_packages\n\nsetup(\n name='latent-diffusion',\n version='0.0.1',\n description=''"
},
{
"path": "utility.py",
"chars": 5326,
"preview": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nimport cv2\nfrom metrics import pytorch_ssim, lpips, flol"
}
]
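Each entry in the condensed preview above is an object with `path`, `chars`, and `preview` keys. The following is a minimal sketch, not part of the repository, of how such a listing could be loaded and summarized. The `preview_json` sample and the `summarize` helper are hypothetical names introduced for illustration; the three entries are copied from the listing with their previews shortened.

```python
import json

# Hypothetical sample: three entries copied from the preview listing,
# with the "preview" strings shortened for brevity.
preview_json = """
[
  {"path": ".gitignore", "chars": 101, "preview": "*.pth\\n*.ckpt"},
  {"path": "LICENSE", "chars": 1068, "preview": "MIT License"},
  {"path": "setup.py", "chars": 233, "preview": "from setuptools import setup, find_packages"}
]
"""


def summarize(entries):
    """Return (file count, total character count, path of the largest file)."""
    total = sum(entry["chars"] for entry in entries)
    largest = max(entries, key=lambda entry: entry["chars"])
    return len(entries), total, largest["path"]


entries = json.loads(preview_json)
count, total_chars, largest_path = summarize(entries)
print(count, total_chars, largest_path)
```

Note that `preview` holds only a truncated snippet of each file, so `chars` (the full file length) is the field to use for size estimates.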
About this extraction
This page contains the full source code of the danier97/LDMVFI GitHub repository, extracted and formatted as plain text. The extraction includes 47 files (425.2 KB), approximately 111.6k tokens, and a symbol index with 577 extracted functions, classes, methods, constants, and types.