Repository: danier97/LDMVFI
Branch: main
Commit: eee2dc3566f2
Files: 47
Total size: 425.2 KB

Directory structure:
gitextract_grs9wzig/

├── .gitignore
├── LICENSE
├── README.md
├── configs/
│   ├── autoencoder/
│   │   └── vqflow-f32.yaml
│   └── ldm/
│       └── ldmvfi-vqflow-f32-c256-concat_max.yaml
├── cupy_module/
│   ├── __init__.py
│   └── dsepconv.py
├── environment.yaml
├── evaluate.py
├── evaluate_vqm.py
├── interpolate_yuv.py
├── ldm/
│   ├── data/
│   │   ├── __init__.py
│   │   ├── bvi_vimeo.py
│   │   ├── testsets.py
│   │   ├── testsets_vqm.py
│   │   └── vfitransforms.py
│   ├── lr_scheduler.py
│   ├── models/
│   │   ├── autoencoder.py
│   │   └── diffusion/
│   │       ├── __init__.py
│   │       ├── ddim.py
│   │       └── ddpm.py
│   ├── modules/
│   │   ├── attention.py
│   │   ├── diffusionmodules/
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── openaimodel.py
│   │   │   └── util.py
│   │   ├── ema.py
│   │   ├── losses/
│   │   │   ├── __init__.py
│   │   │   └── vqperceptual.py
│   │   └── maxvit.py
│   └── util.py
├── main.py
├── metrics/
│   ├── flolpips/
│   │   ├── .gitignore
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── correlation/
│   │   │   └── correlation.py
│   │   ├── flolpips.py
│   │   ├── pretrained_networks.py
│   │   ├── pwcnet.py
│   │   └── utils.py
│   ├── lpips/
│   │   ├── __init__.py
│   │   ├── lpips.py
│   │   └── pretrained_networks.py
│   └── pytorch_ssim/
│       └── __init__.py
├── setup.py
└── utility.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.pth
*.ckpt
*__pycache__*
*.pyc
*egg*
*src/*
*.ipynb
logs/*
*delete*
eval_results*
*.idea*
*.pytorch

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 danielism97

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# LDMVFI: Video Frame Interpolation with Latent Diffusion Models

[**Duolikun Danier**](https://danier97.github.io/), [**Fan Zhang**](https://fan-aaron-zhang.github.io/), [**David Bull**](https://david-bull.github.io/)

[Project](TODO) | [arXiv](https://arxiv.org/abs/2303.09508) | [Video](https://drive.google.com/file/d/1oL6j_l3b2QEqsL0iO7qSZrGUXJaTpRWN/view?usp=share_link)

![Demo gif](assets/ldmvfi.gif)


## Overview
We observe that most existing learning-based VFI models are trained to minimise the L1/L2/VGG loss between their outputs and the ground-truth frames. However, it was shown in previous works that these metrics do not correlate well with the **perceptual quality** of VFI. On the other hand, generative models, especially diffusion models, are showing remarkable results in generating visual content with high perceptual quality. In this work, we leverage the high-fidelity image/video generation capabilities of **latent diffusion models** to perform generative VFI.
<p align="center">
<img src="https://danier97.github.io/LDMVFI/overall.svg" alt="Paper" width="60%">
</p>

## Dependencies and Installation
See [environment.yaml](./environment.yaml) for the required packages. To create the conda environment:
```
conda env create -f environment.yaml
```

## Pre-trained Model
The pre-trained model can be downloaded from [here](https://drive.google.com/file/d/1_Xx2fBYQT9O-6O3zjzX76O9XduGnCh_7/view?usp=share_link), and its corresponding config file is [this yaml](./configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml).


## Preparing datasets
### Training sets:
[[Vimeo-90K]](http://toflow.csail.mit.edu/) | [[BVI-DVC quintuplets]](https://drive.google.com/file/d/1i_CoqiNrZ2AU8DKjU8aHM1jIaDGW0fE5/view?usp=sharing)

### Test sets: 
[[Middlebury]](https://vision.middlebury.edu/flow/data/) | [[UCF101]](https://sites.google.com/view/xiangyuxu/qvi_nips19) | [[DAVIS]](https://sites.google.com/view/xiangyuxu/qvi_nips19) | [[SNU-FILM]](https://myungsub.github.io/CAIN/)


To use [evaluate.py](evaluate.py) and the files in [ldm/data/](./ldm/data/), the dataset folder names should be lower-case and structured as follows.
```
└──── <data directory>/
    ├──── middlebury_others/
    |   ├──── input/
    |   |   ├──── Beanbags/
    |   |   ├──── ...
    |   |   └──── Walking/
    |   └──── gt/
    |       ├──── Beanbags/
    |       ├──── ...
    |       └──── Walking/
    ├──── ucf101/
    |   ├──── 0/
    |   ├──── ...
    |   └──── 99/
    ├──── davis90/
    |   ├──── bear/
    |   ├──── ...
    |   └──── walking/
    ├──── snufilm/
    |   ├──── test-easy.txt
    |   ├──── ...
    |   └──── data/SNU-FILM/test/...
    ├──── bvidvc/quintuplets
    |   ├──── 00000/
    |   ├──── ...
    |   └──── 17599/
    └──── vimeo_septuplet/
        ├──── sequences/
        ├──── sep_testlist.txt
        └──── sep_trainlist.txt
```
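Before running `evaluate.py`, it can help to confirm that the folders match the tree above. The following is a minimal sketch that only checks the sub-directories named in that tree; `EXPECTED_DIRS` and `missing_dataset_dirs` are hypothetical names, not part of the repository.

```python
from pathlib import Path

# Sub-directories expected under <data directory>, taken from the tree above
# (hypothetical list; extend it with any other folders your test sets need).
EXPECTED_DIRS = [
    "middlebury_others/input",
    "middlebury_others/gt",
    "ucf101",
    "davis90",
    "snufilm",
    "bvidvc/quintuplets",
    "vimeo_septuplet/sequences",
]


def missing_dataset_dirs(data_dir, expected=EXPECTED_DIRS):
    """Return the entries of `expected` that are not directories under data_dir."""
    root = Path(data_dir)
    return [rel for rel in expected if not (root / rel).is_dir()]
```

For example, `missing_dataset_dirs("<path/to/data/dir>")` returns an empty list when every listed folder is present. On case-insensitive file systems this check will not catch folder names that are not lower-case.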

## Evaluation

For example, to evaluate LDMVFI (with the DDIM sampler) on the Middlebury dataset using PSNR, SSIM and LPIPS, run the following command.
```
python evaluate.py \
--config configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml \
--ckpt <path/to/ldmvfi-vqflow-f32-c256-concat_max.ckpt> \
--dataset Middlebury_others \
--metrics PSNR SSIM LPIPS \
--data_dir <path/to/data/dir> \
--out_dir eval_results/ldmvfi-vqflow-f32-c256-concat_max/ \
--use_ddim
```
This will create the directory `eval_results/ldmvfi-vqflow-f32-c256-concat_max/Middlebury_others/` and store the interpolated frames, together with a `results.txt` file, in that directory. For other test sets, replace `Middlebury_others` with the corresponding class name defined in [ldm/data/testsets.py](ldm/data/testsets.py) (e.g. `Ucf101_triplet`).
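For reference, PSNR between an interpolated frame and its ground truth follows the standard definition sketched below. This is a generic NumPy version assuming 8-bit images with a peak value of 255; it is not the repository's metric code, which may differ in colour space, data range or border handling.

```python
import numpy as np


def psnr(pred, gt, peak=255.0):
    """Peak signal-to-noise ratio in dB between two same-shaped images."""
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    if pred.shape != gt.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {gt.shape}")
    mse = np.mean((pred - gt) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(peak ** 2 / mse))
```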

\
To evaluate the model on the perceptual video metric FloLPIPS, first evaluate the image metrics using the command above (so that the interpolated frames are saved in `eval_results/ldmvfi-vqflow-f32-c256-concat_max`), then run the following command.
```
python evaluate_vqm.py \
--exp ldmvfi-vqflow-f32-c256-concat_max \
--dataset Middlebury_others \
--metrics FloLPIPS \
--data_dir <path/to/data/dir> \
--out_dir eval_results/ldmvfi-vqflow-f32-c256-concat_max/ \
```
This will read the interpolated frames previously stored in `eval_results/ldmvfi-vqflow-f32-c256-concat_max/Middlebury_others/` then output the evaluation results to `results_vqm.txt` in the same folder.

\
To interpolate a video in .yuv format, use the following command.
```
python interpolate_yuv.py \
--net LDMVFI \
--config configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml \
--ckpt <path/to/ldmvfi-vqflow-f32-c256-concat_max.ckpt> \
--input_yuv <path/to/input/yuv> \
--size <spatial res of video, e.g. 1920x1080> \
--out_fps <output fps, should be 2 x original fps> \
--out_dir <desired/output/dir> \
--use_ddim
```
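A raw .yuv file has no header, which is why `--size` is needed to split it into frames. The README does not state the pixel format; the sketch below assumes 8-bit planar YUV 4:2:0 with even width and height, so each frame occupies width × height × 3/2 bytes. `parse_size` and `read_yuv420_frames` are hypothetical helpers, not functions from `interpolate_yuv.py`.

```python
import numpy as np


def parse_size(size):
    """Parse a 'WIDTHxHEIGHT' string such as '1920x1080'."""
    width, height = (int(v) for v in size.lower().split("x"))
    return width, height


def read_yuv420_frames(raw, width, height):
    """Split raw 8-bit YUV 4:2:0 bytes into a list of (Y, U, V) planes.

    Assumes even width/height, so each chroma plane is (height/2, width/2).
    """
    y_size = width * height
    c_size = (width // 2) * (height // 2)
    frame_size = y_size + 2 * c_size
    if len(raw) % frame_size:
        raise ValueError("file size is not a whole number of frames")
    data = np.frombuffer(raw, dtype=np.uint8)
    frames = []
    for start in range(0, len(data), frame_size):
        f = data[start:start + frame_size]
        y = f[:y_size].reshape(height, width)
        u = f[y_size:y_size + c_size].reshape(height // 2, width // 2)
        v = f[y_size + c_size:].reshape(height // 2, width // 2)
        frames.append((y, u, v))
    return frames
```

Inserting one frame between each consecutive pair roughly doubles the frame count, which is why `--out_fps` should be twice the input frame rate.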

## Training
LDMVFI is trained in two stages, where the VQ-FIGAN and the denoising U-Net are trained separately.
### VQ-FIGAN
```
python main.py --base configs/autoencoder/vqflow-f32.yaml -t --gpus 0,
```
### Denoising U-Net
```
python main.py --base configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml -t --gpus 0,
```
These commands create a `logs/` folder containing a separate directory for each experiment. The training outputs include checkpoints, images and TensorBoard logs.

To resume from a checkpoint file, simply use the `--resume` argument in [main.py](main.py) to specify the checkpoint.


## Citation
```
@article{danier2023ldmvfi,
  title={LDMVFI: Video Frame Interpolation with Latent Diffusion Models},
  author={Danier, Duolikun and Zhang, Fan and Bull, David},
  journal={arXiv preprint arXiv:2303.09508},
  year={2023}
}
```

## Acknowledgement
Our code is adapted from the original [latent-diffusion](https://github.com/CompVis/latent-diffusion) repository. We thank the authors for sharing their code.

================================================
FILE: configs/autoencoder/vqflow-f32.yaml
================================================
model:
  base_learning_rate: 1.0e-5
  target: ldm.models.autoencoder.VQFlowNet
  params:
    monitor: val/total_loss
    embed_dim: 3
    n_embed: 8192
    ddconfig:
      double_z: False
      z_channels: 3
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 64
      ch_mult: [1,2,2,2,4]  # f = 2 ^ len(ch_mult)
      num_res_blocks: 1
      cond_type: max_cross_attn
      attn_type: max
      attn_resolutions: []
      dropout: 0.0

    lossconfig:
      target: ldm.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: False
        disc_in_channels: 3
        disc_start: 10000
        disc_weight: 0.8
        codebook_weight: 1.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 10
    num_workers: 0
    wrap: false
    train:
      target: ldm.data.bvi_vimeo.BVI_Vimeo_triplet
      params:
        db_dir: C:/data_tmp/
        crop_sz: [256,256]
        iter: True
    validation:
      target: ldm.data.bvi_vimeo.Vimeo90k_triplet
      params:
        db_dir: C:/data_tmp/vimeo_septuplet/
        train: False
        crop_sz: [256,256]
        augment_s: False
        augment_t: False


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 8000
        val_batch_frequency: 800
        max_images: 8
        increase_log_steps: False
        log_images_kwargs: {'N': 1}

  trainer:
    benchmark: True
    max_epochs: -1


================================================
FILE: configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml
================================================
model:
  base_learning_rate: 1.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusionVFI
  params:
    linear_start: 0.0015
    linear_end: 0.0195
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: past_future_frames
    image_size: 8
    channels: 3
    cond_stage_trainable: False
    concat_mode: True
    monitor: val/loss_simple_ema
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 8 # img size of latent, used during training, determines some model params, so don't change for inference
        in_channels: 9
        out_channels: 3
        model_channels: 256
        attention_resolutions:
        # note: these are not spatial resolutions but
        # downsampling factors, i.e. with 8x8 latents (f32 on
        # 256x256 crops) this corresponds to attention at
        # spatial resolutions 8, 4 and 2
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        num_head_channels: 32
        use_max_self_attn: True # replace all full self-attention with MaxViT
    first_stage_config:
      target: ldm.models.autoencoder.VQFlowNetInterface
      params:
        ckpt_path: null # must specify pre-trained autoencoding model ckpt to train the denoising UNet
        embed_dim: 3
        n_embed: 8192
        ddconfig:
          double_z: False
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 64
          ch_mult: [1,2,2,2,4]  # f = 2 ^ len(ch_mult)
          num_res_blocks: 1
          cond_type: max_cross_attn
          attn_type: max
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
    cond_stage_config: __is_first_stage__


data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 64
    num_workers: 0
    wrap: false
    train:
      target: ldm.data.bvi_vimeo.BVI_Vimeo_triplet
      params:
        db_dir: C:/data_tmp/
        crop_sz: [256,256]
        iter: True
    validation:
      target: ldm.data.bvi_vimeo.Vimeo90k_triplet
      params:
        db_dir: C:/data_tmp/vimeo_septuplet/
        train: False
        crop_sz: [256,256]
        augment_s: False
        augment_t: False


lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1250
        val_batch_frequency: 125
        max_images: 8
        increase_log_steps: False
        log_images_kwargs: {'N': 1}

  trainer:
    benchmark: True
    max_epochs: -1
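`linear_start`, `linear_end` and `timesteps` define the diffusion noise schedule. The NumPy sketch below assumes this repository keeps latent-diffusion's "linear" schedule, in which the square roots of the betas (not the betas themselves) are evenly spaced. In latent-diffusion this is `make_beta_schedule` in ldm/modules/diffusionmodules/util.py; confirm the copy in this repository before relying on the sketch.

```python
import numpy as np


def linear_beta_schedule(linear_start=0.0015, linear_end=0.0195, timesteps=1000):
    """Betas whose square roots are evenly spaced (assumed latent-diffusion convention)."""
    return np.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=np.float64) ** 2


betas = linear_beta_schedule()
# alpha-bar_t: fraction of the clean latent's variance kept at step t, so that
# q(z_t | z_0) = N(sqrt(alphas_cumprod[t]) * z_0, (1 - alphas_cumprod[t]) * I)
alphas_cumprod = np.cumprod(1.0 - betas)
```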


================================================
FILE: cupy_module/__init__.py
================================================


================================================
FILE: cupy_module/dsepconv.py
================================================
import torch

import cupy
import re


class Stream:
    ptr = torch.cuda.current_stream().cuda_stream


# end

kernel_DSepconv_updateOutput = '''
	extern "C" __global__ void kernel_DSepconv_updateOutput(
		const int n,
		const float* input,
		const float* vertical,
		const float* horizontal,
		const float* offset_x,
		const float* offset_y,
		const float* mask,
		float* output
	) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
		float dblOutput = 0.0;

		const int intSample = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output);
		const int intDepth  = ( intIndex / SIZE_3(output) / SIZE_2(output)                  ) % SIZE_1(output);
		const int intY      = ( intIndex / SIZE_3(output)                                   ) % SIZE_2(output);
		const int intX      = ( intIndex                                                    ) % SIZE_3(output);
		

		for (int intFilterY = 0; intFilterY < SIZE_1(vertical); intFilterY += 1) {
			for (int intFilterX = 0; intFilterX < SIZE_1(horizontal); intFilterX += 1) {
			    float delta_x = OFFSET_4(offset_y, intSample, intFilterY*SIZE_1(vertical) + intFilterX, intY, intX);
			    float delta_y = OFFSET_4(offset_x, intSample, intFilterY*SIZE_1(vertical) + intFilterX, intY, intX);
			    
			    float position_x = delta_x + intX + intFilterX - (SIZE_1(horizontal) - 1) / 2 + 1;
			    float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1;
			    if (position_x < 0)
			        position_x = 0;
			    if (position_x > SIZE_3(input) - 1)
			        position_x = SIZE_3(input) - 1;
			    if (position_y < 0)
			        position_y = 0;
			    if (position_y > SIZE_2(input) - 1)
			        position_y =  SIZE_2(input) - 1;
			    
			    int left = floor(delta_x + intX + intFilterX - (SIZE_1(horizontal) - 1) / 2 + 1);
			    int right = left + 1;
			    if (left < 0)
			        left = 0;
			    if (left > SIZE_3(input) - 1)
			        left = SIZE_3(input) - 1;
			    if (right < 0)
			        right = 0;
			    if (right > SIZE_3(input) - 1)
			        right = SIZE_3(input) - 1;
			    
			    int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1);
			    int bottom = top + 1;
			    if (top < 0)
			        top = 0;
			    if (top > SIZE_2(input) - 1)
			        top =  SIZE_2(input) - 1;
			    if (bottom < 0)
			        bottom = 0;   
			    if (bottom > SIZE_2(input) - 1)
			        bottom = SIZE_2(input) - 1;
			    
			    float floatValue = VALUE_4(input, intSample, intDepth, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + 
			                       VALUE_4(input, intSample, intDepth, top, right) * (1 - (right - position_x)) *  (1 + (top - position_y)) + 
			                       VALUE_4(input, intSample, intDepth, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + 
			                       VALUE_4(input, intSample, intDepth, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y));
			                       
				dblOutput += floatValue * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(horizontal, intSample, intFilterX, intY, intX) * VALUE_4(mask, intSample, SIZE_1(vertical)*intFilterY + intFilterX, intY, intX);
			}
		}
		output[intIndex] = dblOutput;
	} }
'''
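A slow pure-NumPy mirror of `kernel_DSepconv_updateOutput` can be handy for checking the CUDA output on tiny tensors. This sketch was written by reading the kernel above and is not part of the repository; `dsepconv_forward_ref` and `_clamp` are hypothetical names. It assumes `vertical` and `horizontal` share one filter size K (the kernel indexes offsets and mask with `intFilterY*SIZE_1(vertical) + intFilterX`) and reproduces the kernel as written, including the offset_x/offset_y swap, the `+ 1` in the sampling position and the border clamping.

```python
import numpy as np


def _clamp(v, lo, hi):
    return min(max(v, lo), hi)


def dsepconv_forward_ref(inp, vertical, horizontal, offset_x, offset_y, mask):
    """Slow NumPy mirror of kernel_DSepconv_updateOutput.

    Assumed shapes, read off the kernel's indexing:
      inp: (N, C, H_in, W_in); vertical, horizontal: (N, K, H, W);
      offset_x, offset_y, mask: (N, K*K, H, W). Returns (N, C, H, W).
    """
    n, c, h_in, w_in = inp.shape
    _, k, h, w = vertical.shape
    half = (k - 1) // 2  # same value as the kernel's integer (K - 1) / 2
    out = np.zeros((n, c, h, w), dtype=np.float64)
    for s in range(n):
        img = inp[s]  # (C, H_in, W_in)
        for y in range(h):
            for x in range(w):
                for fy in range(k):
                    for fx in range(k):
                        d = fy * k + fx
                        # As in the kernel: offset_y drives x and offset_x drives y.
                        px_raw = offset_y[s, d, y, x] + x + fx - half + 1
                        py_raw = offset_x[s, d, y, x] + y + fy - half + 1
                        px = _clamp(px_raw, 0, w_in - 1)
                        py = _clamp(py_raw, 0, h_in - 1)
                        # Neighbours come from the unclamped position, then are clamped.
                        left = int(np.floor(px_raw))
                        top = int(np.floor(py_raw))
                        right = _clamp(left + 1, 0, w_in - 1)
                        bottom = _clamp(top + 1, 0, h_in - 1)
                        left = _clamp(left, 0, w_in - 1)
                        top = _clamp(top, 0, h_in - 1)
                        # Bilinear weights written exactly as in the CUDA source.
                        wl, wr = 1 + (left - px), 1 - (right - px)
                        wt, wb = 1 + (top - py), 1 - (bottom - py)
                        sample = (img[:, top, left] * wl * wt + img[:, top, right] * wr * wt
                                  + img[:, bottom, left] * wl * wb + img[:, bottom, right] * wr * wb)
                        coef = vertical[s, fy, y, x] * horizontal[s, fx, y, x] * mask[s, d, y, x]
                        out[s, :, y, x] += coef * sample
    return out
```

For example, with K = 1, zero offsets and all-ones filters and mask, each interior output pixel (y, x) equals input pixel (y + 1, x + 1), because of the `+ 1` term in the sampling position.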

kernel_DSepconv_updateGradVertical = '''
	extern "C" __global__ void kernel_DSepconv_updateGradVertical(
		const int n,
		const float* gradLoss,
		const float* input,
		const float* horizontal,
		const float* offset_x,
		const float* offset_y,
		const float* mask,
		float* gradVertical
	) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
		float floatOutput = 0.0;

		const int intSample   = ( intIndex / SIZE_3(gradVertical) / SIZE_2(gradVertical) / SIZE_1(gradVertical) ) % SIZE_0(gradVertical);
		const int intFilterY  = ( intIndex / SIZE_3(gradVertical) / SIZE_2(gradVertical)                        ) % SIZE_1(gradVertical);
		const int intY        = ( intIndex / SIZE_3(gradVertical)                                               ) % SIZE_2(gradVertical);
		const int intX        = ( intIndex                                                                      ) % SIZE_3(gradVertical);

		for (int intFilterX = 0; intFilterX < SIZE_1(horizontal); intFilterX += 1){
		    int intDepth = intFilterY * SIZE_1(horizontal) + intFilterX;
		    float delta_x = OFFSET_4(offset_y, intSample, intDepth, intY, intX);
			float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX);
			
			float position_x = delta_x + intX + intFilterX - (SIZE_1(horizontal) - 1) / 2 + 1;
			float position_y = delta_y + intY + intFilterY - (SIZE_1(horizontal) - 1) / 2 + 1;
			if (position_x < 0)
			    position_x = 0;
			if (position_x > SIZE_3(input) - 1)
			    position_x = SIZE_3(input) - 1;
			if (position_y < 0)
			    position_y = 0;
			if (position_y > SIZE_2(input) - 1)
			    position_y =  SIZE_2(input) - 1;
		
			int left = floor(delta_x + intX + intFilterX - (SIZE_1(horizontal) - 1) / 2 + 1);
			int right = left + 1;
			if (left < 0)
			    left = 0;
			if (left > SIZE_3(input) - 1)
			    left = SIZE_3(input) - 1;
			if (right < 0)
			    right = 0;
			if (right > SIZE_3(input) - 1)
			    right = SIZE_3(input) - 1;

			int top = floor(delta_y + intY + intFilterY - (SIZE_1(horizontal) - 1) / 2 + 1);
			int bottom = top + 1;
			if (top < 0)
			    top = 0;
			if (top > SIZE_2(input) - 1)
			    top =  SIZE_2(input) - 1;
			if (bottom < 0)
			    bottom = 0;   
			if (bottom > SIZE_2(input) - 1)
			    bottom = SIZE_2(input) - 1;
			
			float floatSampled0 = VALUE_4(input, intSample, 0, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + 
			               VALUE_4(input, intSample, 0, top, right) * (1 - (right - position_x)) *  (1 + (top - position_y)) + 
			               VALUE_4(input, intSample, 0, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + 
			               VALUE_4(input, intSample, 0, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y));
			float floatSampled1 = VALUE_4(input, intSample, 1, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + 
			               VALUE_4(input, intSample, 1, top, right) * (1 - (right - position_x)) *  (1 + (top - position_y)) + 
			               VALUE_4(input, intSample, 1, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + 
			               VALUE_4(input, intSample, 1, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y));
			float floatSampled2 = VALUE_4(input, intSample, 2, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + 
			               VALUE_4(input, intSample, 2, top, right) * (1 - (right - position_x)) *  (1 + (top - position_y)) + 
			               VALUE_4(input, intSample, 2, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + 
			               VALUE_4(input, intSample, 2, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y));
			
			floatOutput += VALUE_4(gradLoss, intSample, 0, intY, intX) * floatSampled0 * VALUE_4(horizontal, intSample, intFilterX, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX) +
				       VALUE_4(gradLoss, intSample, 1, intY, intX) * floatSampled1 * VALUE_4(horizontal, intSample, intFilterX, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX) +
				       VALUE_4(gradLoss, intSample, 2, intY, intX) * floatSampled2 * VALUE_4(horizontal, intSample, intFilterX, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX);
		}
		gradVertical[intIndex] = floatOutput;
	} }

'''

kernel_DSepconv_updateGradHorizontal = '''
	extern "C" __global__ void kernel_DSepconv_updateGradHorizontal(
		const int n,
		const float* gradLoss,
		const float* input,
		const float* vertical,
		const float* offset_x,
		const float* offset_y,
		const float* mask,
		float* gradHorizontal
	) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
		float floatOutput = 0.0;

		const int intSample   = ( intIndex / SIZE_3(gradHorizontal) / SIZE_2(gradHorizontal) / SIZE_1(gradHorizontal) ) % SIZE_0(gradHorizontal);
		const int intFilterX  = ( intIndex / SIZE_3(gradHorizontal) / SIZE_2(gradHorizontal)                          ) % SIZE_1(gradHorizontal);
		const int intY        = ( intIndex / SIZE_3(gradHorizontal)                                                   ) % SIZE_2(gradHorizontal);
		const int intX        = ( intIndex                                                                            ) % SIZE_3(gradHorizontal);

		for (int intFilterY = 0; intFilterY < SIZE_1(vertical); intFilterY += 1){
		    int intDepth = intFilterY * SIZE_1(vertical) + intFilterX;
		    float delta_x = OFFSET_4(offset_y, intSample, intDepth, intY, intX);
			float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX);
		
			float position_x = delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1;
			float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1;
			if (position_x < 0)
			    position_x = 0;
			if (position_x > SIZE_3(input) - 1)
			    position_x = SIZE_3(input) - 1;
			if (position_y < 0)
			    position_y = 0;
			if (position_y > SIZE_2(input) - 1)
			    position_y =  SIZE_2(input) - 1;
		
			int left = floor(delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1);
			int right = left + 1;
			if (left < 0)
			    left = 0;
			if (left > SIZE_3(input) - 1)
			    left = SIZE_3(input) - 1;
			if (right < 0)
			    right = 0;
			if (right > SIZE_3(input) - 1)
			    right = SIZE_3(input) - 1;

			int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1);
			int bottom = top + 1;
			if (top < 0)
			    top = 0;
			if (top > SIZE_2(input) - 1)
			    top =  SIZE_2(input) - 1;
			if (bottom < 0)
			    bottom = 0;   
			if (bottom > SIZE_2(input) - 1)
			    bottom = SIZE_2(input) - 1;
			
			float floatSampled0 = VALUE_4(input, intSample, 0, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + 
			               VALUE_4(input, intSample, 0, top, right) * (1 - (right - position_x)) *  (1 + (top - position_y)) + 
			               VALUE_4(input, intSample, 0, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + 
			               VALUE_4(input, intSample, 0, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y));
			float floatSampled1 = VALUE_4(input, intSample, 1, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + 
			               VALUE_4(input, intSample, 1, top, right) * (1 - (right - position_x)) *  (1 + (top - position_y)) + 
			               VALUE_4(input, intSample, 1, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + 
			               VALUE_4(input, intSample, 1, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y));
			float floatSampled2 = VALUE_4(input, intSample, 2, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + 
			               VALUE_4(input, intSample, 2, top, right) * (1 - (right - position_x)) *  (1 + (top - position_y)) + 
			               VALUE_4(input, intSample, 2, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + 
			               VALUE_4(input, intSample, 2, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y));
				
			floatOutput += VALUE_4(gradLoss, intSample, 0, intY, intX) * floatSampled0 * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX) +
				       VALUE_4(gradLoss, intSample, 1, intY, intX) * floatSampled1 * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX) +
				       VALUE_4(gradLoss, intSample, 2, intY, intX) * floatSampled2 * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX);
		}
		gradHorizontal[intIndex] = floatOutput;
	} }
'''

kernel_DSepconv_updateGradMask = '''
	extern "C" __global__ void kernel_DSepconv_updateGradMask(
		const int n,
		const float* gradLoss,
		const float* input,
		const float* vertical,
		const float* horizontal,
		const float* offset_x,
		const float* offset_y,
		float* gradMask
	) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
	    float floatOutput = 0.0;

		const int intSample   = ( intIndex / SIZE_3(gradMask) / SIZE_2(gradMask) / SIZE_1(gradMask) ) % SIZE_0(gradMask);
		const int intDepth    = ( intIndex / SIZE_3(gradMask) / SIZE_2(gradMask)                    ) % SIZE_1(gradMask);
		const int intY        = ( intIndex / SIZE_3(gradMask)                                       ) % SIZE_2(gradMask);
		const int intX        = ( intIndex                                                          ) % SIZE_3(gradMask);
		
		int intFilterY = intDepth / SIZE_1(vertical);
        int intFilterX = intDepth % SIZE_1(vertical);
        
        float delta_x = OFFSET_4(offset_y, intSample, intDepth, intY, intX);
		float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX);
		
		float position_x = delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1;
		float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1;
		if (position_x < 0)
			position_x = 0;
		if (position_x > SIZE_3(input) - 1)
			position_x = SIZE_3(input) - 1;
		if (position_y < 0)
			position_y = 0;
		if (position_y > SIZE_2(input) - 1)
			position_y =  SIZE_2(input) - 1;
		
		int left = floor(delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1);
		int right = left + 1;
		if (left < 0)
			left = 0;
		if (left > SIZE_3(input) - 1)
			left = SIZE_3(input) - 1;
		if (right < 0)
			right = 0;
		if (right > SIZE_3(input) - 1)
			right = SIZE_3(input) - 1;

		int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1);
		int bottom = top + 1;
		if (top < 0)
			top = 0;
		if (top > SIZE_2(input) - 1)
			top =  SIZE_2(input) - 1;
		if (bottom < 0)
			bottom = 0;   
		if (bottom > SIZE_2(input) - 1)
			bottom = SIZE_2(input) - 1;
		
		for (int intChannel = 0; intChannel < 3; intChannel++){
		    floatOutput += VALUE_4(gradLoss, intSample, intChannel, intY, intX) * (
		                   VALUE_4(input, intSample, intChannel, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + 
			               VALUE_4(input, intSample, intChannel, top, right) * (1 - (right - position_x)) *  (1 + (top - position_y)) + 
			               VALUE_4(input, intSample, intChannel, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + 
			               VALUE_4(input, intSample, intChannel, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y))
		                   ) * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(horizontal, intSample, intFilterX, intY, intX);
		} 
		gradMask[intIndex] = floatOutput;
	} }
'''

kernel_DSepconv_updateGradOffsetX = '''
	extern "C" __global__ void kernel_DSepconv_updateGradOffsetX(
		const int n,
		const float* gradLoss,
		const float* input,
		const float* vertical,
		const float* horizontal,
		const float* offset_x,
		const float* offset_y,
		const float* mask,
		float* gradOffsetX
	) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
	    float floatOutput = 0.0;

		const int intSample   = ( intIndex / SIZE_3(gradOffsetX) / SIZE_2(gradOffsetX) / SIZE_1(gradOffsetX) ) % SIZE_0(gradOffsetX);
		const int intDepth    = ( intIndex / SIZE_3(gradOffsetX) / SIZE_2(gradOffsetX)                       ) % SIZE_1(gradOffsetX);
		const int intY        = ( intIndex / SIZE_3(gradOffsetX)                                             ) % SIZE_2(gradOffsetX);
		const int intX        = ( intIndex                                                                   ) % SIZE_3(gradOffsetX);

		int intFilterY = intDepth / SIZE_1(vertical);
        int intFilterX = intDepth % SIZE_1(vertical);

        float delta_x = OFFSET_4(offset_y, intSample, intDepth, intY, intX);
		float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX);

		float position_x = delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1;
		float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1;
		if (position_x < 0)
			position_x = 0;
		if (position_x > SIZE_3(input) - 1)
			position_x = SIZE_3(input) - 1;
		if (position_y < 0)
			position_y = 0;
		if (position_y > SIZE_2(input) - 1)
			position_y =  SIZE_2(input) - 1;
		
		int left = floor(delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1);
		int right = left + 1;
		if (left < 0)
			left = 0;
		if (left > SIZE_3(input) - 1)
			left = SIZE_3(input) - 1;
		if (right < 0)
			right = 0;
		if (right > SIZE_3(input) - 1)
			right = SIZE_3(input) - 1;

		int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1);
		int bottom = top + 1;
		if (top < 0)
			top = 0;
		if (top > SIZE_2(input) - 1)
			top =  SIZE_2(input) - 1;
		if (bottom < 0)
			bottom = 0;   
		if (bottom > SIZE_2(input) - 1)
			bottom = SIZE_2(input) - 1;

		for (int intChannel = 0; intChannel < 3; intChannel++){
			floatOutput += VALUE_4(gradLoss, intSample, intChannel, intY, intX) * (
		                   - VALUE_4(input, intSample, intChannel, top, left)  * (1 + (left - position_x))
		                   - VALUE_4(input, intSample, intChannel, top, right)  *  (1 - (right - position_x))
			               + VALUE_4(input, intSample, intChannel, bottom, left) * (1 + (left - position_x))
			               + VALUE_4(input, intSample, intChannel, bottom, right) * (1 - (right - position_x))
			               )
		                   * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(horizontal, intSample, intFilterX, intY, intX)
		                   * VALUE_4(mask, intSample, intDepth, intY, intX);
		} 
		gradOffsetX[intIndex] = floatOutput;
	} }
'''

kernel_DSepconv_updateGradOffsetY = '''
	extern "C" __global__ void kernel_DSepconv_updateGradOffsetY(
		const int n,
		const float* gradLoss,
		const float* input,
		const float* vertical,
		const float* horizontal,
		const float* offset_x,
		const float* offset_y,
		const float* mask,
		float* gradOffsetY
	) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
	    float floatOutput = 0.0;

		const int intSample   = ( intIndex / SIZE_3(gradOffsetX) / SIZE_2(gradOffsetX) / SIZE_1(gradOffsetX) ) % SIZE_0(gradOffsetX);
		const int intDepth    = ( intIndex / SIZE_3(gradOffsetX) / SIZE_2(gradOffsetX)                       ) % SIZE_1(gradOffsetX);
		const int intY        = ( intIndex / SIZE_3(gradOffsetX)                                             ) % SIZE_2(gradOffsetX);
		const int intX        = ( intIndex                                                                   ) % SIZE_3(gradOffsetX);

		int intFilterY = intDepth / SIZE_1(vertical);
        int intFilterX = intDepth % SIZE_1(vertical);

        float delta_x = OFFSET_4(offset_y, intSample, intDepth, intY, intX);
		float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX);

		float position_x = delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1;
		float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1;
		if (position_x < 0)
			position_x = 0;
		if (position_x > SIZE_3(input) - 1)
			position_x = SIZE_3(input) - 1;
		if (position_y < 0)
			position_y = 0;
		if (position_y > SIZE_2(input) - 1)
			position_y =  SIZE_2(input) - 1;
		
		int left = floor(delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1);
		int right = left + 1;
		if (left < 0)
			left = 0;
		if (left > SIZE_3(input) - 1)
			left = SIZE_3(input) - 1;
		if (right < 0)
			right = 0;
		if (right > SIZE_3(input) - 1)
			right = SIZE_3(input) - 1;

		int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1);
		int bottom = top + 1;
		if (top < 0)
			top = 0;
		if (top > SIZE_2(input) - 1)
			top =  SIZE_2(input) - 1;
		if (bottom < 0)
			bottom = 0;   
		if (bottom > SIZE_2(input) - 1)
			bottom = SIZE_2(input) - 1;

		for (int intChannel = 0; intChannel < 3; intChannel++){
		    floatOutput += VALUE_4(gradLoss, intSample, intChannel, intY, intX) * (
		                   - VALUE_4(input, intSample, intChannel, top, left)  * (1 + (top - position_y)) 
		                   + VALUE_4(input, intSample, intChannel, top, right)  *  (1 + (top - position_y)) 
			               - VALUE_4(input, intSample, intChannel, bottom, left) * (1 - (bottom - position_y)) 
			               + VALUE_4(input, intSample, intChannel, bottom, right) * (1 - (bottom - position_y))
			               )
		                   * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(horizontal, intSample, intFilterX, intY, intX)
		                   * VALUE_4(mask, intSample, intDepth, intY, intX);
		} 
		gradOffsetY[intIndex] = floatOutput;
	} }
'''


def cupy_kernel(strFunction, objectVariables):
    strKernel = globals()[strFunction]

    while True:
        objectMatch = re.search(r'(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)

        if objectMatch is None:
            break
        # end

        intArg = int(objectMatch.group(2))

        strTensor = objectMatch.group(4)
        intSizes = objectVariables[strTensor].size()

        strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg]))
    # end

    while True:
        objectMatch = re.search(r'(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)

        if objectMatch is None:
            break
        # end

        intArgs = int(objectMatch.group(2))
        strArgs = objectMatch.group(4).split(',')

        strTensor = strArgs[0]
        intStrides = objectVariables[strTensor].stride()
        strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(
            intStrides[intArg]) + ')' for intArg in range(intArgs)]

        strKernel = strKernel.replace(objectMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
    # end

    while True:
        objectMatch = re.search(r'(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel)

        if objectMatch is None:
            break
        # end

        intArgs = int(objectMatch.group(2))
        strArgs = objectMatch.group(4).split(',')

        strTensor = strArgs[0]
        intStrides = objectVariables[strTensor].stride()
        strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(
            intStrides[intArg]) + ')' for intArg in range(intArgs)]

        strKernel = strKernel.replace(objectMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
    # end

    return strKernel


# end

@cupy.memoize(for_each_device=True)
def cupy_launch(strFunction, strKernel):
    # return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
    return cupy.RawKernel(strKernel, strFunction)


# end

class _FunctionDSepconv(torch.autograd.Function):
    @staticmethod
    def forward(self, input, vertical, horizontal, offset_x, offset_y, mask):
        self.save_for_backward(input, vertical, horizontal, offset_x, offset_y, mask)

        intSample = input.size(0)
        intInputDepth = input.size(1)
        intInputHeight = input.size(2)
        intInputWidth = input.size(3)
        intFilterSize = min(vertical.size(1), horizontal.size(1))
        intOutputHeight = min(vertical.size(2), horizontal.size(2))
        intOutputWidth = min(vertical.size(3), horizontal.size(3))

        assert (intInputHeight == intOutputHeight + intFilterSize - 1)
        assert (intInputWidth == intOutputWidth + intFilterSize - 1)

        assert (input.is_contiguous() == True)
        assert (vertical.is_contiguous() == True)
        assert (horizontal.is_contiguous() == True)
        assert (offset_x.is_contiguous() == True)
        assert (offset_y.is_contiguous() == True)
        assert (mask.is_contiguous() == True)

        output = input.new_zeros([intSample, intInputDepth, intOutputHeight, intOutputWidth])

        if input.is_cuda == True:
            n = output.nelement()
            cupy_launch('kernel_DSepconv_updateOutput', cupy_kernel('kernel_DSepconv_updateOutput', {
                'input': input,
                'vertical': vertical,
                'horizontal': horizontal,
                'offset_x': offset_x,
                'offset_y': offset_y,
                'mask': mask,
                'output': output
            }))(
                grid=tuple([int((n + 512 - 1) / 512), 1, 1]),
                block=tuple([512, 1, 1]),
                args=[n, input.data_ptr(), vertical.data_ptr(), horizontal.data_ptr(), offset_x.data_ptr(), offset_y.data_ptr(),
                      mask.data_ptr(), output.data_ptr()],
                stream=Stream
            )

        elif input.is_cuda == False:
            raise NotImplementedError()

        # end

        return output

    # end

    @staticmethod
    def backward(self, gradOutput):
        input, vertical, horizontal, offset_x, offset_y, mask = self.saved_tensors

        intSample = input.size(0)
        intInputDepth = input.size(1)
        intInputHeight = input.size(2)
        intInputWidth = input.size(3)
        intFilterSize = min(vertical.size(1), horizontal.size(1))
        intOutputHeight = min(vertical.size(2), horizontal.size(2))
        intOutputWidth = min(vertical.size(3), horizontal.size(3))

        assert (intInputHeight == intOutputHeight + intFilterSize - 1)
        assert (intInputWidth == intOutputWidth + intFilterSize - 1)

        assert (gradOutput.is_contiguous() == True)

        # backward() launches no kernel for gradInput, so if it is requested it is returned as zeros.
        gradInput = input.new_zeros([intSample, intInputDepth, intInputHeight, intInputWidth]) if \
            self.needs_input_grad[0] == True else None
        gradVertical = input.new_zeros([intSample, intFilterSize, intOutputHeight, intOutputWidth]) if \
            self.needs_input_grad[1] == True else None
        gradHorizontal = input.new_zeros([intSample, intFilterSize, intOutputHeight, intOutputWidth]) if \
            self.needs_input_grad[2] == True else None
        gradOffsetX = input.new_zeros([intSample, intFilterSize * intFilterSize, intOutputHeight, intOutputWidth]) if \
            self.needs_input_grad[3] == True else None
        gradOffsetY = input.new_zeros([intSample, intFilterSize * intFilterSize, intOutputHeight, intOutputWidth]) if \
            self.needs_input_grad[4] == True else None
        gradMask = input.new_zeros([intSample, intFilterSize * intFilterSize, intOutputHeight, intOutputWidth]) if \
            self.needs_input_grad[5] == True else None

        if input.is_cuda == True:
            # Only launch the kernels whose gradients autograd actually requested.
            if gradVertical is not None:
                nv = gradVertical.nelement()
                cupy_launch('kernel_DSepconv_updateGradVertical', cupy_kernel('kernel_DSepconv_updateGradVertical', {
                    'gradLoss': gradOutput,
                    'input': input,
                    'horizontal': horizontal,
                    'offset_x': offset_x,
                    'offset_y': offset_y,
                    'mask': mask,
                    'gradVertical': gradVertical
                }))(
                    grid=tuple([int((nv + 512 - 1) / 512), 1, 1]),
                    block=tuple([512, 1, 1]),
                    args=[nv, gradOutput.data_ptr(), input.data_ptr(), horizontal.data_ptr(), offset_x.data_ptr(),
                          offset_y.data_ptr(), mask.data_ptr(), gradVertical.data_ptr()],
                    stream=Stream
                )

            if gradHorizontal is not None:
                nh = gradHorizontal.nelement()
                cupy_launch('kernel_DSepconv_updateGradHorizontal', cupy_kernel('kernel_DSepconv_updateGradHorizontal', {
                    'gradLoss': gradOutput,
                    'input': input,
                    'vertical': vertical,
                    'offset_x': offset_x,
                    'offset_y': offset_y,
                    'mask': mask,
                    'gradHorizontal': gradHorizontal
                }))(
                    grid=tuple([int((nh + 512 - 1) / 512), 1, 1]),
                    block=tuple([512, 1, 1]),
                    args=[nh, gradOutput.data_ptr(), input.data_ptr(), vertical.data_ptr(), offset_x.data_ptr(),
                          offset_y.data_ptr(), mask.data_ptr(), gradHorizontal.data_ptr()],
                    stream=Stream
                )

            if gradOffsetX is not None:
                nx = gradOffsetX.nelement()
                cupy_launch('kernel_DSepconv_updateGradOffsetX', cupy_kernel('kernel_DSepconv_updateGradOffsetX', {
                    'gradLoss': gradOutput,
                    'input': input,
                    'vertical': vertical,
                    'horizontal': horizontal,
                    'offset_x': offset_x,
                    'offset_y': offset_y,
                    'mask': mask,
                    'gradOffsetX': gradOffsetX
                }))(
                    grid=tuple([int((nx + 512 - 1) / 512), 1, 1]),
                    block=tuple([512, 1, 1]),
                    args=[nx, gradOutput.data_ptr(), input.data_ptr(), vertical.data_ptr(), horizontal.data_ptr(), offset_x.data_ptr(),
                          offset_y.data_ptr(), mask.data_ptr(), gradOffsetX.data_ptr()],
                    stream=Stream
                )

            if gradOffsetY is not None:
                ny = gradOffsetY.nelement()
                cupy_launch('kernel_DSepconv_updateGradOffsetY', cupy_kernel('kernel_DSepconv_updateGradOffsetY', {
                    'gradLoss': gradOutput,
                    'input': input,
                    'vertical': vertical,
                    'horizontal': horizontal,
                    'offset_x': offset_x,
                    'offset_y': offset_y,
                    'mask': mask,
                    # The GradOffsetY kernel source refers to its output as 'gradOffsetX' in its
                    # SIZE_* macros, so gradOffsetY is registered under that key on purpose.
                    'gradOffsetX': gradOffsetY
                }))(
                    grid=tuple([int((ny + 512 - 1) / 512), 1, 1]),
                    block=tuple([512, 1, 1]),
                    args=[ny, gradOutput.data_ptr(), input.data_ptr(), vertical.data_ptr(), horizontal.data_ptr(),
                          offset_x.data_ptr(),
                          offset_y.data_ptr(), mask.data_ptr(), gradOffsetY.data_ptr()],
                    stream=Stream
                )

            if gradMask is not None:
                nm = gradMask.nelement()
                cupy_launch('kernel_DSepconv_updateGradMask', cupy_kernel('kernel_DSepconv_updateGradMask', {
                    'gradLoss': gradOutput,
                    'input': input,
                    'vertical': vertical,
                    'horizontal': horizontal,
                    'offset_x': offset_x,
                    'offset_y': offset_y,
                    'gradMask': gradMask
                }))(
                    grid=tuple([int((nm + 512 - 1) / 512), 1, 1]),
                    block=tuple([512, 1, 1]),
                    args=[nm, gradOutput.data_ptr(), input.data_ptr(), vertical.data_ptr(), horizontal.data_ptr(),
                          offset_x.data_ptr(),
                          offset_y.data_ptr(), gradMask.data_ptr()],
                    stream=Stream
                )

        elif input.is_cuda == False:
            raise NotImplementedError()

        # end

        return gradInput, gradVertical, gradHorizontal, gradOffsetX, gradOffsetY, gradMask


# end
# end

def FunctionDSepconv(tensorInput, tensorVertical, tensorHorizontal, tensorOffsetX, tensorOffsetY, tensorMask):
    return _FunctionDSepconv.apply(tensorInput, tensorVertical, tensorHorizontal, tensorOffsetX, tensorOffsetY, tensorMask)


# end

class ModuleDSepconv(torch.nn.Module):
    def __init__(self):
        super(ModuleDSepconv, self).__init__()

    # end

    def forward(self, tensorInput, tensorVertical, tensorHorizontal, tensorOffsetX, tensorOffsetY, tensorMask):
        return _FunctionDSepconv.apply(tensorInput, tensorVertical, tensorHorizontal, tensorOffsetX, tensorOffsetY, tensorMask)
# end
# end
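`cupy_kernel` turns the `SIZE_k(t)`, `VALUE_k(t, ...)` and `OFFSET_k(t, ...)` placeholders in the kernel strings into literal sizes and flat, stride-based indices before the source is compiled. The sketch below re-implements that substitution in plain Python so the generated code can be inspected without CuPy or a GPU. `FakeTensor` and `expand_macros` are hypothetical names, and strides are assumed to be counted in elements (as `torch.Tensor.stride()` reports) for a contiguous tensor.

```python
import re
from math import prod


class FakeTensor:
    """Hypothetical stand-in exposing size() and stride() like a contiguous torch.Tensor."""

    def __init__(self, *shape):
        self._shape = tuple(shape)

    def size(self):
        return self._shape

    def stride(self):
        # Row-major (contiguous) strides, counted in elements rather than bytes.
        return tuple(prod(self._shape[i + 1:]) for i in range(len(self._shape)))


def expand_macros(src, tensors):
    """Sketch of the substitution cupy_kernel performs before compilation."""
    # SIZE_k(t) -> literal size of dimension k of tensor t.
    src = re.sub(r'SIZE_([0-4])\(([^)]*)\)',
                 lambda m: str(tensors[m.group(2).strip()].size()[int(m.group(1))]), src)

    # VALUE_k(t, i0, ..., i{k-1}) and OFFSET_k(...) -> t[((i0)*s0)+...+((i{k-1})*s{k-1})].
    def flat_index(m):
        ndim = int(m.group(1))
        args = [a.strip() for a in m.group(2).split(',')]
        strides = tensors[args[0]].stride()
        # Braces let an index expression contain parentheses without ending the match early.
        terms = ['((' + args[i + 1].replace('{', '(').replace('}', ')') + ')*' + str(strides[i]) + ')'
                 for i in range(ndim)]
        return args[0] + '[' + '+'.join(terms) + ']'

    return re.sub(r'(?:VALUE|OFFSET)_([0-4])\(([^)]+)\)', flat_index, src)


tensors = {'input': FakeTensor(1, 3, 8, 8)}
print(expand_macros('VALUE_4(input, intSample, intChannel, top, left) / SIZE_3(input)', tensors))
# input[((intSample)*192)+((intChannel)*64)+((top)*8)+((left)*1)] / 8
```

Because sizes and strides are baked into the source as literals, every distinct tensor shape yields a different kernel string; `cupy_launch` is memoized on its arguments per device, so each such string is compiled once per device and then reused.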

# float floatValue = VALUE_4(input, intSample, intDepth, top, left) * (1 - (delta_x - floor(delta_x))) * (1 - (delta_y - floor(delta_y))) +
# 			                       VALUE_4(input, intSample, intDepth, top, right) * (delta_x - floor(delta_x)) *  (1 - (delta_y - floor(delta_y))) +
# 			                       VALUE_4(input, intSample, intDepth, bottom, left) * (1 - (delta_x - floor(delta_x))) * (delta_y - floor(delta_y)) +
# 			                       VALUE_4(input, intSample, intDepth, bottom, right) * (delta_x - floor(delta_x)) * (delta_y - floor(delta_y));
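The commented-out expression at the end of `cupy_module/dsepconv.py` is the bilinear sample read at a fractional position, and the `updateGradOffsetX`/`updateGradOffsetY` kernels accumulate its partial derivatives with respect to that position. The NumPy sketch below checks the closed-form derivatives against central finite differences. It is a simplified illustration under stated assumptions: a single channel, interior positions only, no border clamping, and hypothetical helper names. Note that in the kernels `delta_x` is read from `offset_y` and `delta_y` from `offset_x`, so `gradOffsetX` ends up holding the derivative along the row (y) axis, which is the `d_dy` term below.

```python
import numpy as np


def bilinear(img, x, y):
    # Sample img (H x W) at fractional (x, y); assumes 0 <= x < W-1 and 0 <= y < H-1.
    left, top = int(np.floor(x)), int(np.floor(y))
    a, b = x - left, y - top  # fractional parts
    return (img[top, left] * (1 - a) * (1 - b) + img[top, left + 1] * a * (1 - b)
            + img[top + 1, left] * (1 - a) * b + img[top + 1, left + 1] * a * b)


def bilinear_grad(img, x, y):
    # Closed-form partials d/dx and d/dy of the bilinear sample (same interior assumption).
    left, top = int(np.floor(x)), int(np.floor(y))
    a, b = x - left, y - top
    d_dx = ((img[top, left + 1] - img[top, left]) * (1 - b)
            + (img[top + 1, left + 1] - img[top + 1, left]) * b)
    d_dy = ((img[top + 1, left] - img[top, left]) * (1 - a)
            + (img[top + 1, left + 1] - img[top, left + 1]) * a)
    return d_dx, d_dy


rng = np.random.default_rng(0)
img = rng.random((6, 7))
x, y, eps = 2.3, 3.6, 1e-6
d_dx, d_dy = bilinear_grad(img, x, y)
fd_dx = (bilinear(img, x + eps, y) - bilinear(img, x - eps, y)) / (2 * eps)
fd_dy = (bilinear(img, x, y + eps) - bilinear(img, x, y - eps)) / (2 * eps)
print(abs(d_dx - fd_dx) < 1e-6, abs(d_dy - fd_dy) < 1e-6)
```

Within one pixel cell the sample is linear in each coordinate, so the finite differences agree with the analytic partials up to floating-point rounding; the derivative is only piecewise defined, jumping at integer positions and at clamped borders.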

================================================
FILE: environment.yaml
================================================
name: ldmvfi
channels:
  - pytorch
  - defaults
  - conda-forge
dependencies:
  - python=3.9.13
  - pytorch=1.11.0
  - torchvision=0.12.0
  - cudatoolkit=11.3
  - pip:
    - opencv-python==4.6.0.66
    - pudb==2022.1.3
    - imageio==2.22.3
    - imageio-ffmpeg==0.4.7
    - pytorch-lightning==1.7.7
    - omegaconf==2.2.3
    - test-tube==0.7.5
    - streamlit==1.14.0
    - einops==0.5.0
    - torch-fidelity==0.3.0
    - transformers==4.23.1
    - timm==0.6.12
    - cupy
    - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
    - -e git+https://github.com/openai/CLIP.git@main#egg=clip
    - -e .

# conda create -n ldmvfi python=3.9
# conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
# pip install opencv-python==4.6.0.66 pudb==2022.1.3 imageio==2.22.3 imageio-ffmpeg==0.4.7 pytorch-lightning==1.7.7 omegaconf==2.2.3  test-tube==0.7.5 streamlit==1.14.0  einops==0.5.0 torch-fidelity==0.3.0 transformers==4.23.1
# pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
# pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
# pip install -e .
# pip install timm==0.6.12
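Since the YAML pins exact versions for its pip packages, one way to spot drift in an existing environment is to compare installed distribution versions against those pins. The sketch below is hypothetical and not part of the repository: it uses the standard-library `importlib.metadata`, and `PINS` simply copies a few pins from the file above.

```python
from importlib import metadata

# A subset of the pip pins from environment.yaml (copied by hand; extend as needed).
PINS = {
    'pytorch-lightning': '1.7.7',
    'omegaconf': '2.2.3',
    'einops': '0.5.0',
    'timm': '0.6.12',
}


def find_mismatches(pins, lookup=metadata.version):
    """Return {package: (wanted, found)} for packages whose installed version differs.

    `found` is None when the package is not installed.
    """
    mismatches = {}
    for name, wanted in pins.items():
        try:
            found = lookup(name)
        except metadata.PackageNotFoundError:
            found = None
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


if __name__ == '__main__':
    for name, (wanted, found) in find_mismatches(PINS).items():
        print(f'{name}: wanted {wanted}, found {found}')
```

The `lookup` parameter only exists so the comparison can be exercised without touching the real environment.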

================================================
FILE: evaluate.py
================================================
import argparse
import os
import torch
from functools import partial
from omegaconf import OmegaConf
from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.data import testsets


parser = argparse.ArgumentParser(description='Frame Interpolation Evaluation')

parser.add_argument('--config', type=str, default=None)
parser.add_argument('--ckpt', type=str, default=None)
parser.add_argument('--dataset', type=str, default='Middlebury_others')
parser.add_argument('--metrics', nargs='+', type=str, default=['PSNR', 'SSIM', 'LPIPS'])
parser.add_argument('--data_dir', type=str, default='D:\\')
parser.add_argument('--out_dir', type=str, default='eval_results')
parser.add_argument('--resume', dest='resume', default=False, action='store_true')

# sampler args
parser.add_argument('--use_ddim', dest='use_ddim', default=False, action='store_true')
parser.add_argument('--ddim_eta', type=float, default=1.0)
parser.add_argument('--ddim_steps', type=int, default=200)

def main():

    args = parser.parse_args()
    
    # initialise model
    config = OmegaConf.load(args.config)
    model = instantiate_from_config(config.model)
    model.load_state_dict(torch.load(args.ckpt)['state_dict'])
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    model = model.eval()
    print('Model loaded successfully')

    # set up sampler
    if args.use_ddim:
        ddim = DDIMSampler(model)
        sample_func = partial(ddim.sample, S=args.ddim_steps, eta=args.ddim_eta, verbose=False)
    else:
        sample_func = partial(model.sample_ddpm, return_intermediates=False, verbose=False)

    # setup output dirs
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    # initialise test set
    print('Testing on dataset: ', args.dataset)
    test_dir = os.path.join(args.out_dir, args.dataset)
    if args.dataset.split('_')[0] in ['VFITex', 'Ucf101', 'Davis90']:
        db_folder = args.dataset.split('_')[0].lower()
    else:
        db_folder = args.dataset.lower()
    test_db = getattr(testsets, args.dataset)(os.path.join(args.data_dir, db_folder))
    if not os.path.exists(test_dir):
        os.mkdir(test_dir)
    test_db.eval(model, sample_func, metrics=args.metrics, output_dir=test_dir, resume=args.resume)



if __name__ == '__main__':
    main()
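`evaluate.py` (and `evaluate_vqm.py`, which repeats the same logic) derives the data sub-folder from the `--dataset` name: for the `VFITex`, `Ucf101` and `Davis90` families only the prefix before the first underscore is used, otherwise the whole name is lower-cased. A standalone version of that rule, useful for checking where a test set is expected to live under `--data_dir` (the function name is hypothetical):

```python
import os

# Dataset families whose folder is named after the prefix only (from evaluate.py).
PREFIX_FAMILIES = ('VFITex', 'Ucf101', 'Davis90')


def dataset_folder(dataset, data_dir):
    prefix = dataset.split('_')[0]
    db_folder = prefix.lower() if prefix in PREFIX_FAMILIES else dataset.lower()
    return os.path.join(data_dir, db_folder)
```

For example, any `Ucf101_*` variant resolves to `<data_dir>/ucf101`, while `Middlebury_others` resolves to `<data_dir>/middlebury_others`.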

================================================
FILE: evaluate_vqm.py
================================================
import argparse
import os
from ldm.data import testsets_vqm


parser = argparse.ArgumentParser(description='Frame Interpolation Evaluation')

parser.add_argument('--exp', type=str, default=None)
parser.add_argument('--dataset', type=str, default='Middlebury_others')
parser.add_argument('--metrics', nargs='+', type=str, default=['FloLPIPS'])
parser.add_argument('--data_dir', type=str, default='D:\\')
parser.add_argument('--out_dir', type=str, default='eval_results')
parser.add_argument('--resume', dest='resume', default=False, action='store_true')


def main():

    args = parser.parse_args()
    
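    # name of the model/experiment whose interpolated frames are evaluated
    model = args.exp
    print('Evaluating model:', model)

    # the interpolated frames must already exist in the output directory
    assert os.path.exists(args.out_dir), 'Frames not previously interpolated!'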
    
    # initialise test set
    print('Testing on dataset: ', args.dataset)
    test_dir = os.path.join(args.out_dir, args.dataset)
    assert os.path.exists(test_dir), f'{args.dataset} not pre-computed!'

    if args.dataset.split('_')[0] in ['VFITex', 'Ucf101', 'Davis90']:
        db_folder = args.dataset.split('_')[0].lower()
    else:
        db_folder = args.dataset.lower()

    test_db = getattr(testsets_vqm, args.dataset)(os.path.join(args.data_dir, db_folder))
    test_db.eval(metrics=args.metrics, output_dir=test_dir, resume=args.resume)



if __name__ == '__main__':
    main()

================================================
FILE: interpolate_yuv.py
================================================
import argparse
import torch
import torchvision.transforms.functional as TF
import os
from PIL import Image
from tqdm import tqdm
import skvideo.io
from functools import partial
from utility import read_frame_yuv2rgb, tensor2rgb
from omegaconf import OmegaConf
from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler


parser = argparse.ArgumentParser(description='Frame interpolation of a raw YUV video')

parser.add_argument('--net', type=str, default='LDMVFI')
parser.add_argument('--config', type=str, default='configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml')
parser.add_argument('--ckpt', type=str, default='ckpt.pth')
parser.add_argument('--input_yuv', type=str, default='D:\\')
parser.add_argument('--size', type=str, default='1920x1080')
parser.add_argument('--out_fps', type=int, default=60)
parser.add_argument('--out_dir', type=str, default='.')

# sampler args
parser.add_argument('--use_ddim', dest='use_ddim', default=False, action='store_true')
parser.add_argument('--ddim_eta', type=float, default=1.0)
parser.add_argument('--ddim_steps', type=int, default=200)


def main():
    args = parser.parse_args()

    # initialise model
    config = OmegaConf.load(args.config)
    model = instantiate_from_config(config.model)
    model.load_state_dict(torch.load(args.ckpt)['state_dict'])
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    model = model.eval()
    print('Model loaded successfully')

    # set up sampler
    if args.use_ddim:
        ddim = DDIMSampler(model)
        sample_func = partial(ddim.sample, S=args.ddim_steps, eta=args.ddim_eta, verbose=False)
    else:
        sample_func = partial(model.sample_ddpm, return_intermediates=False, verbose=False)

    # Setup output file
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)
    _, fname = os.path.split(args.input_yuv)
    seq_name = os.path.splitext(fname)[0]  # str.strip('.yuv') would also eat leading/trailing 'y', 'u', 'v' characters
    width, height = args.size.split('x')
    bit_depth = 16 if '16bit' in fname else 10 if '10bit' in fname else 8
    pix_fmt = '444' if '444' in fname else '420'
    try:
        width = int(width)
        height = int(height)
    except ValueError:
        print('Invalid size, should be \'<width>x<height>\'')
        return 

    outname = '{}_{}x{}_{}fps_{}.mp4'.format(seq_name, width, height, args.out_fps, args.net)
    writer = skvideo.io.FFmpegWriter(os.path.join(args.out_dir, outname), 
        inputdict={
            '-r': str(args.out_fps)
        },
        outputdict={
            '-pix_fmt': 'yuv420p',
            '-s': '{}x{}'.format(width,height),
            '-r': str(args.out_fps),
            '-vcodec': 'libx264',  # use the H.264 codec
            '-crf': '0',           # a constant rate factor of 0 is lossless
            '-preset': 'veryslow'  # slower presets compress better; for other options
                                   # see https://trac.ffmpeg.org/wiki/Encode/H.264
        }
    ) 

    # Start interpolation
    print('Using model {} to upsample file {}'.format(args.net, fname))
    stream = open(args.input_yuv, 'r')
    file_size = os.path.getsize(args.input_yuv)

    # YUV reading setup
    bytes_per_frame = width*height*1.5
    if pix_fmt == '444':
        bytes_per_frame *= 2
    if bit_depth != 8:
        bytes_per_frame *= 2

    num_frames = int(file_size // bytes_per_frame)
    rawFrame0 = Image.fromarray(read_frame_yuv2rgb(stream, width, height, 0, bit_depth, pix_fmt))
    frame0 = TF.normalize(TF.to_tensor(rawFrame0), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))[None,...].to(device)
    for t in tqdm(range(1, num_frames)):
        rawFrame1 = Image.fromarray(read_frame_yuv2rgb(stream, width, height, t, bit_depth, pix_fmt))
        frame1 = TF.normalize(TF.to_tensor(rawFrame1), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))[None,...].to(device)

        with torch.no_grad():
            with model.ema_scope():
                # form condition tensor and define shape of latent rep
                xc = {'prev_frame': frame0, 'next_frame': frame1}
                c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc)
                shape = (model.channels, c.shape[2], c.shape[3])
                # run sampling and get denoised latent
                out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape)
                if isinstance(out, tuple): # using ddim
                    out = out[0]
                # reconstruct interpolated frame from latent
                out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list)
                out = torch.clamp(out, min=-1., max=1.)  # interpolated frame in [-1, 1]

        # write to output video
        writer.writeFrame(tensor2rgb(frame0)[0])
        writer.writeFrame(tensor2rgb(out)[0])

        # update frame0
        frame0 = frame1
    
    # write the last frame (after the loop frame0 holds it; this also covers a single-frame input)
    writer.writeFrame(tensor2rgb(frame0)[0])

    stream.close()
    writer.close() # close the writer


if __name__ == "__main__":
    main()
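`interpolate_yuv.py` assumes raw planar YUV whose layout is encoded in the file name: `16bit`/`10bit` select 2-byte samples, `444` selects full-resolution chroma, and everything else is treated as 8-bit 4:2:0. The sketch below reproduces that arithmetic with hypothetical helper names, and also counts output frames: every input pair writes its earlier frame plus one interpolated frame, and the final input frame is written once at the end, so N input frames become 2N - 1 output frames.

```python
import os


def parse_yuv_name(path, size='1920x1080'):
    """Hypothetical helper mirroring the naming conventions interpolate_yuv.py relies on."""
    fname = os.path.basename(path)
    seq_name = os.path.splitext(fname)[0]
    width, height = (int(v) for v in size.split('x'))
    bit_depth = 16 if '16bit' in fname else 10 if '10bit' in fname else 8
    pix_fmt = '444' if '444' in fname else '420'
    return seq_name, width, height, bit_depth, pix_fmt


def bytes_per_frame(width, height, bit_depth, pix_fmt):
    # 4:2:0 = full-res luma + two quarter-res chroma planes (1.5 samples per pixel);
    # 4:4:4 = three full-res planes (3 samples per pixel).
    samples = width * height * (3 if pix_fmt == '444' else 1.5)
    # Anything deeper than 8 bits is stored as 2 bytes per sample.
    return int(samples * (2 if bit_depth > 8 else 1))


def num_frames(file_size, width, height, bit_depth, pix_fmt):
    return file_size // bytes_per_frame(width, height, bit_depth, pix_fmt)


def output_frame_count(n_input):
    # Each of the n-1 input pairs writes its first frame plus one interpolated frame,
    # and the last input frame is written once at the end.
    return 2 * n_input - 1 if n_input > 0 else 0
```

For instance, an 8-bit 4:2:0 1920x1080 frame occupies 3,110,400 bytes, so a file of exactly five such frames produces a nine-frame output video at `--out_fps`.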


================================================
FILE: ldm/data/__init__.py
================================================


================================================
FILE: ldm/data/bvi_vimeo.py
================================================
import numpy as np
import random
from os import listdir
from os.path import join, isdir, split, getsize
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
from PIL import Image
import ldm.data.vfitransforms as vt
from functools import partial

class Vimeo90k_triplet(Dataset):
    def __init__(self, db_dir, train=True,  crop_sz=(256,256), augment_s=True, augment_t=True):
        seq_dir = join(db_dir, 'sequences')
        self.crop_sz = crop_sz
        self.augment_s = augment_s
        self.augment_t = augment_t

        if train:
            seq_list_txt = join(db_dir, 'sep_trainlist.txt')
        else:
            seq_list_txt = join(db_dir, 'sep_testlist.txt')

        with open(seq_list_txt) as f:
            contents = f.readlines()
            seq_path = [line.strip() for line in contents if line != '\n']

        self.seq_path_list = [join(seq_dir, *line.split('/')) for line in seq_path]

    def __getitem__(self, index):
        rawFrame3 = Image.open(join(self.seq_path_list[index],  "im3.png"))
        rawFrame4 = Image.open(join(self.seq_path_list[index],  "im4.png"))
        rawFrame5 = Image.open(join(self.seq_path_list[index],  "im5.png"))

        if self.crop_sz is not None:
            rawFrame3, rawFrame4, rawFrame5 = vt.rand_crop(rawFrame3, rawFrame4, rawFrame5, sz=self.crop_sz)

        if self.augment_s:
            rawFrame3, rawFrame4, rawFrame5 = vt.rand_flip(rawFrame3, rawFrame4, rawFrame5, p=0.5)
        
        if self.augment_t:
            rawFrame3, rawFrame4, rawFrame5 = vt.rand_reverse(rawFrame3, rawFrame4, rawFrame5, p=0.5)

        to_array = partial(np.array, dtype=np.float32)
        frame3, frame4, frame5 = map(to_array, (rawFrame3, rawFrame4, rawFrame5)) #(256,256,3), 0-255

        frame3 = frame3/127.5 - 1.0
        frame4 = frame4/127.5 - 1.0
        frame5 = frame5/127.5 - 1.0

        return {'image': frame4, 'prev_frame': frame3, 'next_frame': frame5}

    def __len__(self):
        return len(self.seq_path_list)
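The triplet datasets map 8-bit pixels to [-1, 1] with `x / 127.5 - 1`, while `interpolate_yuv.py` uses `TF.normalize(TF.to_tensor(...), 0.5, 0.5)`, i.e. `(x / 255 - 0.5) / 0.5` for 8-bit PIL images. Both are the same affine map, which the NumPy check below confirms. The inverse `to_uint8` is a hypothetical helper; the repository's own `tensor2rgb` is not shown here, so its exact rounding is not assumed.

```python
import numpy as np

pixels = np.arange(256, dtype=np.float32)  # every possible 8-bit value

dataset_style = pixels / 127.5 - 1.0               # Vimeo90k_triplet / BVIDVC_triplet
torchvision_style = (pixels / 255.0 - 0.5) / 0.5   # to_tensor (x / 255), then normalize(mean=0.5, std=0.5)


def to_uint8(x):
    # Hypothetical inverse: map [-1, 1] back to 8-bit values with rounding and clipping.
    return np.clip(np.round((x + 1.0) * 127.5), 0, 255).astype(np.uint8)


print(np.abs(dataset_style - torchvision_style).max())
```

Keeping training and inference on the same range matters because the model's outputs are clamped to [-1, 1] before being converted back to RGB.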


class Vimeo90k_quintuplet(Dataset):
    def __init__(self, db_dir, train=True,  crop_sz=(256,256), augment_s=True, augment_t=True):
        seq_dir = join(db_dir, 'sequences')
        self.crop_sz = crop_sz
        self.augment_s = augment_s
        self.augment_t = augment_t

        if train:
            seq_list_txt = join(db_dir, 'sep_trainlist.txt')
        else:
            seq_list_txt = join(db_dir, 'sep_testlist.txt')

        with open(seq_list_txt) as f:
            contents = f.readlines()
            seq_path = [line.strip() for line in contents if line != '\n']

        self.seq_path_list = [join(seq_dir, *line.split('/')) for line in seq_path]

    def __getitem__(self, index):
        rawFrame1 = Image.open(join(self.seq_path_list[index],  "im1.png"))
        rawFrame3 = Image.open(join(self.seq_path_list[index],  "im3.png"))
        rawFrame4 = Image.open(join(self.seq_path_list[index],  "im4.png"))
        rawFrame5 = Image.open(join(self.seq_path_list[index],  "im5.png"))
        rawFrame7 = Image.open(join(self.seq_path_list[index],  "im7.png"))

        if self.crop_sz is not None:
            rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_crop(rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7, sz=self.crop_sz)

        if self.augment_s:
            rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_flip(rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7, p=0.5)
        
        if self.augment_t:
            rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_reverse(rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7, p=0.5)

        frame1, frame3, frame4, frame5, frame7 = map(TF.to_tensor, (rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7))

        return frame1, frame3, frame4, frame5, frame7

    def __len__(self):
        return len(self.seq_path_list)

    
class BVIDVC_triplet(Dataset):
    def __init__(self, db_dir, res=None, crop_sz=(256,256), augment_s=True, augment_t=True):

        db_dir = join(db_dir, 'quintuplets')
        self.crop_sz = crop_sz
        self.augment_s = augment_s
        self.augment_t = augment_t
        self.seq_path_list = [join(db_dir, f) for f in listdir(db_dir)]

    def __getitem__(self, index):

        cat = Image.open(join(self.seq_path_list[index], 'quintuplet.png'))

        rawFrame3 = cat.crop((256, 0, 256*2, 256))
        rawFrame5 = cat.crop((256*2, 0, 256*3, 256))
        rawFrame4 = cat.crop((256*4, 0, 256*5, 256))

        if self.crop_sz is not None:
            rawFrame3, rawFrame4, rawFrame5 = vt.rand_crop(rawFrame3, rawFrame4, rawFrame5, sz=self.crop_sz)

        if self.augment_s:
            rawFrame3, rawFrame4, rawFrame5 = vt.rand_flip(rawFrame3, rawFrame4, rawFrame5, p=0.5)
        
        if self.augment_t:
            rawFrame3, rawFrame4, rawFrame5 = vt.rand_reverse(rawFrame3, rawFrame4, rawFrame5, p=0.5)

        to_array = partial(np.array, dtype=np.float32)
        frame3, frame4, frame5 = map(to_array, (rawFrame3, rawFrame4, rawFrame5)) #(256,256,3), 0-255

        frame3 = frame3/127.5 - 1.0
        frame4 = frame4/127.5 - 1.0
        frame5 = frame5/127.5 - 1.0

        return {'image': frame4, 'prev_frame': frame3, 'next_frame': frame5}

    def __len__(self):
        return len(self.seq_path_list)
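`BVIDVC_triplet` and `BVIDVC_quintuplet` read each sample from a single `quintuplet.png` strip of 256x256 tiles. Judging from the crop boxes, the tiles are stored left to right as frames 1, 3, 5, 7 and then the middle target frame 4 (an inference from the code, not a documented format). Note also that `BVIDVC_quintuplet` accepts `crop_sz` but applies no random crop. A small helper producing the same PIL crop boxes (names are hypothetical):

```python
TILE = 256
# Tile order along the strip, inferred from the crop boxes in bvi_vimeo.py.
STRIP_ORDER = (1, 3, 5, 7, 4)


def crop_box(frame_idx, tile=TILE):
    # PIL crop box (left, upper, right, lower) for a given frame index.
    col = STRIP_ORDER.index(frame_idx)
    return (tile * col, 0, tile * (col + 1), tile)
```

With Pillow, `cat.crop(crop_box(4))` would then return the target frame that the dataset classes call `rawFrame4`.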


class BVIDVC_quintuplet(Dataset):
    def __init__(self, db_dir, res=None, crop_sz=(256,256), augment_s=True, augment_t=True):

        db_dir = join(db_dir, 'quintuplets')
        self.crop_sz = crop_sz
        self.augment_s = augment_s
        self.augment_t = augment_t
        self.seq_path_list = [join(db_dir, f) for f in listdir(db_dir)]

    def __getitem__(self, index):

        cat = Image.open(join(self.seq_path_list[index], 'quintuplet.png'))

        rawFrame1 = cat.crop((0, 0, 256, 256))
        rawFrame3 = cat.crop((256, 0, 256*2, 256))
        rawFrame5 = cat.crop((256*2, 0, 256*3, 256))
        rawFrame7 = cat.crop((256*3, 0, 256*4, 256))
        rawFrame4 = cat.crop((256*4, 0, 256*5, 256))

        if self.augment_s:
            rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_flip(rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7, p=0.5)
        
        if self.augment_t:
            rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_reverse(rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7, p=0.5)

        frame1, frame3, frame4, frame5, frame7 = map(TF.to_tensor, (rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7))

        return frame1, frame3, frame4, frame5, frame7

    def __len__(self):
        return len(self.seq_path_list)


class Sampler(Dataset):
    def __init__(self, datasets, p_datasets=None, iter=False, samples_per_epoch=1000):
        self.datasets = datasets
        self.len_datasets = np.array([len(dataset) for dataset in self.datasets])
        self.p_datasets = p_datasets
        self.iter = iter

        if p_datasets is None:
            self.p_datasets = self.len_datasets / np.sum(self.len_datasets)

        self.samples_per_epoch = samples_per_epoch

        self.accum = [0,]
        for length in self.len_datasets:
            self.accum.append(self.accum[-1] + length)

    def __getitem__(self, index):
        if self.iter:
            # iterate through all datasets
            for i in range(len(self.accum)):
                if index < self.accum[i]:
                    return self.datasets[i-1].__getitem__(index-self.accum[i-1])
        else:
            # first sample a dataset
            dataset = random.choices(self.datasets, self.p_datasets)[0]
            # sample a sequence from the dataset
            return dataset.__getitem__(random.randint(0,len(dataset)-1))
            

    def __len__(self):
        if self.iter:
            return int(np.sum(self.len_datasets))
        else:
            return self.samples_per_epoch
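

# Illustrative sketch (hypothetical helper, not part of the original repo):
# with iter=True, Sampler maps a global index to (dataset, local index) by
# scanning the cumulative lengths in self.accum. The same lookup can be done
# with a binary search from the standard-library bisect module. Unlike the
# loop above, which falls through and returns None, this raises IndexError
# for out-of-range indices.
import bisect


def locate(accum, index):
    """Return (dataset_idx, local_idx) for a global index.

    accum is the list of cumulative lengths starting at 0, e.g. [0, 3, 8]
    for two datasets of lengths 3 and 5.
    """
    if not 0 <= index < accum[-1]:
        raise IndexError(index)
    i = bisect.bisect_right(accum, index) - 1
    return i, index - accum[i]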


class BVI_Vimeo_triplet(Dataset):
    def __init__(self, db_dir, crop_sz=[256,256], p_datasets=None, iter=False, samples_per_epoch=1000):
        vimeo90k_train = Vimeo90k_triplet(join(db_dir, 'vimeo_septuplet'), train=True,  crop_sz=crop_sz)
        bvidvc_train = BVIDVC_triplet(join(db_dir, 'bvidvc'), crop_sz=crop_sz)

        self.datasets = [vimeo90k_train, bvidvc_train]
        self.len_datasets = np.array([len(dataset) for dataset in self.datasets])
        self.p_datasets = p_datasets
        self.iter = iter

        if p_datasets is None:
            self.p_datasets = self.len_datasets / np.sum(self.len_datasets)

        self.samples_per_epoch = samples_per_epoch

        self.accum = [0,]
        for length in self.len_datasets:
            self.accum.append(self.accum[-1] + length)

    def __getitem__(self, index):
        if self.iter:
            # iterate through all datasets
            for i in range(len(self.accum)):
                if index < self.accum[i]:
                    return self.datasets[i-1].__getitem__(index-self.accum[i-1])
        else:
            # first sample a dataset
            dataset = random.choices(self.datasets, self.p_datasets)[0]
            # sample a sequence from the dataset
            return dataset.__getitem__(random.randint(0,len(dataset)-1))
            

    def __len__(self):
        if self.iter:
            return int(np.sum(self.len_datasets))
        else:
            return self.samples_per_epoch

================================================
FILE: ldm/data/testsets.py
================================================
import glob
from typing import List
from PIL import Image
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.utils import save_image as imwrite
import os
from os.path import join, exists
import utility
import numpy as np
import ast
import time
from ldm.models.autoencoder import * 


class TripletTestSet:
    def __init__(self):
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1,1]

    def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_dir=None, output_name='output.png', resume=False):
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results.txt')), 'no res file found to resume from!'
            with open(join(output_dir, 'results.txt'), 'r') as f:
                prev_lines = f.readlines()
                for line in prev_lines:
                    if len(line) < 2:
                        continue
                    cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
                    for k in metrics:
                        results_dict[k].append(float(cur_res[k]))
                    start_idx += 1
        
        logfile = open(join(output_dir, 'results.txt'), 'a')
        for idx in range(len(self.im_list)):
            if resume and idx < start_idx:
                assert os.path.exists(join(output_dir, self.im_list[idx], output_name)), f'skipping idx {idx} but output not found!'
                continue

            print(f'Evaluating {self.im_list[idx]}')
            t0 = time.time()
            if not exists(join(output_dir, self.im_list[idx])):
                os.makedirs(join(output_dir, self.im_list[idx]))

            with torch.no_grad():
                with model.ema_scope():
                    # form condition tensor and define shape of latent rep
                    xc = {'prev_frame': self.input0_list[idx], 'next_frame': self.input1_list[idx]}
                    c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc)
                    shape = (model.channels, c.shape[2], c.shape[3])
                    # run sampling and get denoised latent rep
                    out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape, x_T=None)
                    if isinstance(out, tuple): # using ddim
                        out = out[0]
                    # reconstruct interpolated frame from latent
                    out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list)
                    out = torch.clamp(out, min=-1., max=1.) # interpolated frame in [-1,1]

            gt = self.gt_list[idx]

            for metric in metrics:
                score = getattr(utility, 'calc_{}'.format(metric.lower()))(gt, out, [self.input0_list[idx], self.input1_list[idx]])[0].item()
                results_dict[metric].append(score)

            imwrite(out, join(output_dir, self.im_list[idx], output_name), value_range=(-1, 1), normalize=True)

            msg = '{:<15s} -- {}'.format(self.im_list[idx], {k: round(results_dict[k][-1],3) for k in metrics}) + f'    time taken: {round(time.time()-t0,2)}' + '\n'
            print(msg, end='')
            logfile.write(msg)

        msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]),3) for k in metrics}) + '\n\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()
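

# Illustrative sketch (hypothetical helper, not part of the original repo):
# eval(..., resume=True) recovers earlier scores by parsing each line that
# eval itself wrote to results.txt, in the form
#   '<name padded to 15> -- {<metric>: <score>, ...}    time taken: <secs>'
# The dict literal sits between '-- ' and 'time', so ast.literal_eval can
# read it back without executing arbitrary code. Note that the final
# 'Average' line has no 'time' field but still parses with this logic.
import ast


def parse_result_line(line):
    """Return the metric dict logged on one results.txt line."""
    return ast.literal_eval(line.strip().split('-- ')[1].split('time')[0])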

class Middlebury_others(TripletTestSet):
    def __init__(self, db_dir):
        super(Middlebury_others, self).__init__()
        self.im_list = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking']

        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input0_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame10.png'))).cuda().unsqueeze(0)) # [1,3,H,W] in [-1,1]
            self.input1_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame11.png'))).cuda().unsqueeze(0)) # [1,3,H,W] in [-1,1]
            self.gt_list.append(self.transform(Image.open(join(db_dir, 'gt', item , 'frame10i11.png'))).cuda().unsqueeze(0))

class Davis(TripletTestSet):
    def __init__(self, db_dir):
        super(Davis, self).__init__()
        self.im_list = ['bike-trial', 'boxing', 'burnout', 'choreography', 'demolition', 'dive-in', 'dolphins', 'e-bike', 'grass-chopper', 'hurdles', 'inflatable', 'juggle', 'kart-turn', 'kids-turning', 'lions', 'mbike-santa', 'monkeys', 'ocean-birds', 'pole-vault', 'running', 'selfie', 'skydive', 'speed-skating', 'swing-boy', 'tackle', 'turtle', 'varanus-tree', 'vietnam', 'wings-turn']

        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input0_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame10.png'))).cuda().unsqueeze(0))
            self.input1_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame11.png'))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, 'gt', item , 'frame10i11.png'))).cuda().unsqueeze(0))


class Ucf(TripletTestSet):
    def __init__(self, db_dir):
        super(Ucf, self).__init__()
        self.im_list = os.listdir(db_dir)

        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input0_list.append(self.transform(Image.open(join(db_dir, item , 'frame_00.png'))).cuda().unsqueeze(0))
            self.input1_list.append(self.transform(Image.open(join(db_dir, item , 'frame_02.png'))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, item , 'frame_01_gt.png'))).cuda().unsqueeze(0))


class Snufilm(TripletTestSet):
    def __init__(self, db_dir, mode):
        super(Snufilm, self).__init__()     
        self.mode = mode
        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        with open(join(db_dir, 'test-{}.txt'.format(mode)), 'r') as f:
            triplet_list = f.read().splitlines()
        self.im_list = []
        for i, triplet in enumerate(triplet_list, 1):
            self.im_list.append('{}-{}'.format(mode, str(i).zfill(3)))
            lst = triplet.split(' ')
            self.input0_list.append(self.transform(Image.open(join(db_dir, lst[0]))).cuda().unsqueeze(0))
            self.input1_list.append(self.transform(Image.open(join(db_dir, lst[2]))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, lst[1]))).cuda().unsqueeze(0))


class Snufilm_easy(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-5]  # strip trailing '/easy' to get the SNU-FILM root
        super(Snufilm_easy, self).__init__(db_dir, 'easy')

class Snufilm_medium(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-7]  # strip trailing '/medium' to get the SNU-FILM root
        super(Snufilm_medium, self).__init__(db_dir, 'medium')

class Snufilm_hard(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-5]  # strip trailing '/hard' to get the SNU-FILM root
        super(Snufilm_hard, self).__init__(db_dir, 'hard')

class Snufilm_extreme(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-8]  # strip trailing '/extreme' to get the SNU-FILM root
        super(Snufilm_extreme, self).__init__(db_dir, 'extreme')
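

# Illustrative sketch (hypothetical helper, not part of the original repo):
# the Snufilm_* subclasses above drop the trailing '/<mode>' from db_dir by
# slicing off len('/<mode>') characters (5, 7, 5 and 8), which assumes db_dir
# ends exactly in '/easy', '/medium', '/hard' or '/extreme' with no trailing
# separator. This variant checks that assumption instead of slicing blindly.
import os


def snufilm_root(db_dir, mode):
    """Return the SNU-FILM root directory given '<root>/<mode>'."""
    head, tail = os.path.split(os.path.normpath(db_dir))
    if tail != mode:
        raise ValueError(f"expected db_dir to end with '{mode}', got {db_dir!r}")
    return head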

class VFITex_triplet:
    def __init__(self, db_dir):
        self.seq_list = os.listdir(db_dir)
        self.db_dir = db_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1,1]


    def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_dir=None, output_name=None, resume=False):
        model.eval()
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results.txt')), 'no res file found to resume from!'
            with open(join(output_dir, 'results.txt'), 'r') as f:
                prev_lines = f.readlines()
                for line in prev_lines:
                    if len(line) < 2:
                        continue
                    cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
                    for k in metrics:
                        results_dict[k].append(float(cur_res[k]))
                    start_idx += 1
        
        logfile = open(join(output_dir, 'results.txt'), 'a')

        for idx, seq in enumerate(self.seq_list):
            if resume and idx < start_idx:
                assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'skipping idx {idx} but output not found!'
                continue

            print(f'Evaluating {seq}')
            t0 = time.time()

            seqpath = join(self.db_dir, seq)
            if not exists(join(output_dir, seq)):
                os.makedirs(join(output_dir, seq))

            # interpolate the middle frame of each triplet, stepping t by 2
            tmp_dict = {k : [] for k in metrics}
            num_frames = len([f for f in os.listdir(seqpath) if f.endswith('.png')])
            for t in range(1, num_frames-5, 2):
                im0 = Image.open(join(seqpath, str(t+2).zfill(3)+'.png'))
                im1 = Image.open(join(seqpath, str(t+3).zfill(3)+'.png'))
                im2 = Image.open(join(seqpath, str(t+4).zfill(3)+'.png'))
                # center crop if 4K
                if '4K' in seq:
                    w, h = im0.size
                    im0 = TF.center_crop(im0, (h//2, w//2))
                    im1 = TF.center_crop(im1, (h//2, w//2))
                    im2 = TF.center_crop(im2, (h//2, w//2))
                im0 = self.transform(im0).cuda().unsqueeze(0)
                im1 = self.transform(im1).cuda().unsqueeze(0)
                im2 = self.transform(im2).cuda().unsqueeze(0)

                with torch.no_grad():
                    with model.ema_scope():
                        # form condition tensor and define shape of latent rep
                        xc = {'prev_frame': im0, 'next_frame': im2}
                        c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc)
                        shape = (model.channels, c.shape[2], c.shape[3])
                        # run sampling and get denoised latent rep
                        out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape, x_T=None)
                        if isinstance(out, tuple): # using ddim
                            out = out[0]
                        # reconstruct interpolated frame from latent
                        out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list)
                        out = torch.clamp(out, min=-1., max=1.) # interpolated frame in [-1,1]

                for metric in metrics:
                    score = getattr(utility, 'calc_{}'.format(metric.lower()))(im1, out, [im0, im2])[0].item()
                    tmp_dict[metric].append(score)

                imwrite(out, join(output_dir, seq, 'frame{}.png'.format(t+3)), value_range=(-1, 1), normalize=True)

            # compute sequence-level scores
            for metric in metrics:
                results_dict[metric].append(np.mean(tmp_dict[metric]))

            # log
            msg = '{:<15s} -- {}'.format(seq, {k: round(results_dict[k][-1], 3) for k in metrics}) + f'    time taken: {round(time.time()-t0,2)}' + '\n'
            print(msg, end='')
            logfile.write(msg)
        
        msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]), 3) for k in metrics}) + '\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()
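

# Illustrative sketch (hypothetical helper, not part of the original repo):
# VFITex_triplet.eval loops t over range(1, num_frames - 5, 2) and reads
# frames t+2, t+3 and t+4 (zero-padded to 3 digits, '.png'), interpolating
# the middle frame from its two neighbours. This lists the file-name
# triplets that loop visits for a given frame count.
def vfitex_triplets(num_frames):
    """Return [(prev, gt, next), ...] file names used by VFITex_triplet.eval."""
    def name(i):
        return str(i).zfill(3) + '.png'
    return [(name(t + 2), name(t + 3), name(t + 4)) for t in range(1, num_frames - 5, 2)]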




class Davis90_triplet:
    def __init__(self, db_dir):
        self.seq_list = sorted(os.listdir(db_dir))
        self.db_dir = db_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1,1]


    def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_dir=None, output_name=None, resume=False):
        model.eval()
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results.txt')), 'no res file found to resume from!'
            with open(join(output_dir, 'results.txt'), 'r') as f:
                prev_lines = f.readlines()
                for line in prev_lines:
                    if len(line) < 2:
                        continue
                    cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
                    for k in metrics:
                        results_dict[k].append(float(cur_res[k]))
                    start_idx += 1
        
        logfile = open(join(output_dir, 'results.txt'), 'a')

        for idx, seq in enumerate(self.seq_list):
            if resume and idx < start_idx:
                assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'skipping idx {idx} but output not found!'
                continue

            print(f'Evaluating {seq}')
            t0 = time.time()

            seqpath = join(self.db_dir, seq)
            if not exists(join(output_dir, seq)):
                os.makedirs(join(output_dir, seq))

            # interpolate the middle frame of each triplet, stepping t by 2
            tmp_dict = {k : [] for k in metrics}
            num_frames = len(os.listdir(seqpath))
            for t in range(0, num_frames-6, 2):
                im3 = Image.open(join(seqpath, str(t+2).zfill(5)+'.jpg'))
                im4 = Image.open(join(seqpath, str(t+3).zfill(5)+'.jpg'))
                im5 = Image.open(join(seqpath, str(t+4).zfill(5)+'.jpg'))

                im3 = self.transform(im3).cuda().unsqueeze(0)
                im4 = self.transform(im4).cuda().unsqueeze(0)
                im5 = self.transform(im5).cuda().unsqueeze(0)

                with torch.no_grad():
                    with model.ema_scope():
                        # form condition tensor and define shape of latent rep
                        xc = {'prev_frame': im3, 'next_frame': im5}
                        c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc)
                        shape = (model.channels, c.shape[2], c.shape[3])
                        # run sampling and get denoised latent rep
                        out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape, x_T=None)
                        if isinstance(out, tuple): # using ddim
                            out = out[0]
                        # reconstruct interpolated frame from latent
                        out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list)
                        out = torch.clamp(out, min=-1., max=1.) # interpolated frame in [-1,1]

                for metric in metrics:
                    score = getattr(utility, 'calc_{}'.format(metric.lower()))(im4, out, [im3, im5])[0].item()
                    tmp_dict[metric].append(score)

                imwrite(out, join(output_dir, seq, 'frame{}.png'.format(t+3)), value_range=(-1, 1), normalize=True)

            # compute sequence-level scores
            for metric in metrics:
                results_dict[metric].append(np.mean(tmp_dict[metric]))

            # log
            msg = '{:<15s} -- {}'.format(seq, {k: round(results_dict[k][-1], 3) for k in metrics}) + f'    time taken: {round(time.time()-t0,2)}' + '\n'
            print(msg, end='')
            logfile.write(msg)
        
        msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]), 3) for k in metrics}) + '\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()



class Ucf101_triplet:
    def __init__(self, db_dir):
        self.db_dir = db_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1,1]

        self.im_list = os.listdir(db_dir)

        self.input3_list = []
        self.input5_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input3_list.append(self.transform(Image.open(join(db_dir, item , 'frame1.png'))).cuda().unsqueeze(0))
            self.input5_list.append(self.transform(Image.open(join(db_dir, item , 'frame2.png'))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, item , 'framet.png'))).cuda().unsqueeze(0))

    def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_dir=None, output_name='output.png', resume=False):
        model.eval()
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results.txt')), 'no res file found to resume from!'
            with open(join(output_dir, 'results.txt'), 'r') as f:
                prev_lines = f.readlines()
                for line in prev_lines:
                    if len(line) < 2:
                        continue
                    cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
                    for k in metrics:
                        results_dict[k].append(float(cur_res[k]))
                    start_idx += 1
        
        logfile = open(join(output_dir, 'results.txt'), 'a')

        for idx in range(len(self.im_list)):
            if resume and idx < start_idx:
                assert os.path.exists(join(output_dir, self.im_list[idx], output_name)), f'skipping idx {idx} but output not found!'
                continue

            print(f'Evaluating {self.im_list[idx]}')
            t0 = time.time()

            if not exists(join(output_dir, self.im_list[idx])):
                os.makedirs(join(output_dir, self.im_list[idx]))

            with torch.no_grad():
                with model.ema_scope():
                    # form condition tensor and define shape of latent rep
                    xc = {'prev_frame': self.input3_list[idx], 'next_frame': self.input5_list[idx]}
                    c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc)
                    shape = (model.channels, c.shape[2], c.shape[3])
                    # run sampling and get denoised latent rep
                    out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape, x_T=None)
                    if isinstance(out, tuple): # using ddim
                        out = out[0]
                    # reconstruct interpolated frame from latent
                    out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list)
                    out = torch.clamp(out, min=-1., max=1.) # interpolated frame in [-1,1]

            gt = self.gt_list[idx]

            for metric in metrics:
                score = getattr(utility, 'calc_{}'.format(metric.lower()))(gt, out, [self.input3_list[idx], self.input5_list[idx]])[0].item()
                results_dict[metric].append(score)

            imwrite(out, join(output_dir, self.im_list[idx], output_name), value_range=(-1, 1), normalize=True)

            msg = '{:<15s} -- {}'.format(self.im_list[idx], {k: round(results_dict[k][-1],3) for k in metrics}) + f'    time taken: {round(time.time()-t0,2)}' + '\n'
            print(msg, end='')
            logfile.write(msg)

        msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]),3) for k in metrics}) + '\n\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()
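

# Illustrative sketch (hypothetical names, not part of the original repo):
# every eval() above resolves a metric name such as 'PSNR' to a function
# named calc_psnr via getattr(utility, 'calc_' + metric.lower()). The toy
# namespace below stands in for the utility module to show that lookup; the
# toy PSNR (flat lists, values in [-1, 1], peak-to-peak range 2) does not
# claim to match utility.calc_psnr.
import math
import types


def _toy_calc_psnr(gt, out, inputs=None):
    # mean squared error over paired values; MAX^2 = 2^2 = 4 for [-1, 1] data
    mse = sum((g - o) ** 2 for g, o in zip(gt, out)) / len(gt)
    return float('inf') if mse == 0 else 10 * math.log10(4.0 / mse)


toy_utility = types.SimpleNamespace(calc_psnr=_toy_calc_psnr)


def compute_metric(namespace, metric, gt, out, inputs=None):
    """Look up namespace.calc_<metric> by name and apply it."""
    return getattr(namespace, 'calc_{}'.format(metric.lower()))(gt, out, inputs)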

================================================
FILE: ldm/data/testsets_vqm.py
================================================
import glob
from typing import List
from PIL import Image
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.utils import save_image as imwrite
import os
from os.path import join, exists
import utility
import numpy as np
import ast
import time


class TripletTestSet:
    def __init__(self):
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1,1]

    def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name='output.png', resume=False):
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results_vqm.txt')), 'No res file found to resume from!'
            with open(join(output_dir, 'results_vqm.txt'), 'r') as f:
                prev_lines = f.readlines()
                for line in prev_lines:
                    if len(line) < 2:
                        continue
                    cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
                    for k in metrics:
                        results_dict[k].append(float(cur_res[k]))
                    start_idx += 1
        
        logfile = open(join(output_dir, 'results_vqm.txt'), 'a')

        for idx in range(len(self.im_list)):
            if resume and idx < start_idx:
                assert os.path.exists(join(output_dir, self.im_list[idx], output_name)), f'skipping idx {idx} but output not found!'
                continue

            print(f'Evaluating {self.im_list[idx]}')
            t0 = time.time()
            assert exists(join(output_dir, self.im_list[idx], output_name)), f'No interpolated frames found for {self.im_list[idx]}'

            out = self.transform(Image.open(join(output_dir, self.im_list[idx], output_name))).cuda().unsqueeze(0)
            gt = self.gt_list[idx]

            for metric in metrics:
                score = getattr(utility, 'calc_{}'.format(metric.lower()))([gt], [out], [self.input0_list[idx], self.input1_list[idx]])
                results_dict[metric].append(score)


            msg = '{:<15s} -- {}'.format(self.im_list[idx], {k: round(results_dict[k][-1],3) for k in metrics}) + f'    time taken: {round(time.time()-t0,2)}' + '\n'
            print(msg, end='')
            logfile.write(msg)

        msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]),3) for k in metrics}) + '\n\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()

class Middlebury_others(TripletTestSet):
    def __init__(self, db_dir):
        super(Middlebury_others, self).__init__()
        self.im_list = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking']

        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input0_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame10.png'))).cuda().unsqueeze(0)) # [1,3,H,W] in [-1,1]
            self.input1_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame11.png'))).cuda().unsqueeze(0)) # [1,3,H,W] in [-1,1]
            self.gt_list.append(self.transform(Image.open(join(db_dir, 'gt', item , 'frame10i11.png'))).cuda().unsqueeze(0))

class Davis(TripletTestSet):
    def __init__(self, db_dir):
        super(Davis, self).__init__()
        self.im_list = ['bike-trial', 'boxing', 'burnout', 'choreography', 'demolition', 'dive-in', 'dolphins', 'e-bike', 'grass-chopper', 'hurdles', 'inflatable', 'juggle', 'kart-turn', 'kids-turning', 'lions', 'mbike-santa', 'monkeys', 'ocean-birds', 'pole-vault', 'running', 'selfie', 'skydive', 'speed-skating', 'swing-boy', 'tackle', 'turtle', 'varanus-tree', 'vietnam', 'wings-turn']

        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input0_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame10.png'))).cuda().unsqueeze(0))
            self.input1_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame11.png'))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, 'gt', item , 'frame10i11.png'))).cuda().unsqueeze(0))


class Ucf(TripletTestSet):
    def __init__(self, db_dir):
        super(Ucf, self).__init__()
        self.im_list = os.listdir(db_dir)

        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input0_list.append(self.transform(Image.open(join(db_dir, item , 'frame_00.png'))).cuda().unsqueeze(0))
            self.input1_list.append(self.transform(Image.open(join(db_dir, item , 'frame_02.png'))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, item , 'frame_01_gt.png'))).cuda().unsqueeze(0))


class Snufilm(TripletTestSet):
    def __init__(self, db_dir, mode):
        super(Snufilm, self).__init__()     
        self.mode = mode
        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        with open(join(db_dir, 'test-{}.txt'.format(mode)), 'r') as f:
            triplet_list = f.read().splitlines()
        self.im_list = []
        for i, triplet in enumerate(triplet_list, 1):
            self.im_list.append('{}-{}'.format(mode, str(i).zfill(3)))
            lst = triplet.split(' ')
            self.input0_list.append(self.transform(Image.open(join(db_dir, lst[0]))).cuda().unsqueeze(0))
            self.input1_list.append(self.transform(Image.open(join(db_dir, lst[2]))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, lst[1]))).cuda().unsqueeze(0))


class Snufilm_easy(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-5]  # strip trailing '/easy' to get the SNU-FILM root
        super(Snufilm_easy, self).__init__(db_dir, 'easy')

class Snufilm_medium(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-7]  # strip trailing '/medium' to get the SNU-FILM root
        super(Snufilm_medium, self).__init__(db_dir, 'medium')

class Snufilm_hard(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-5]  # strip trailing '/hard' to get the SNU-FILM root
        super(Snufilm_hard, self).__init__(db_dir, 'hard')

class Snufilm_extreme(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-8]  # strip trailing '/extreme' to get the SNU-FILM root
        super(Snufilm_extreme, self).__init__(db_dir, 'extreme')

class VFITex_triplet:
    def __init__(self, db_dir):
        self.seq_list = os.listdir(db_dir)
        self.db_dir = db_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1,1]


    def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name=None, resume=False):
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results_vqm.txt')), 'no res file found to resume from!'
            with open(join(output_dir, 'results_vqm.txt'), 'r') as f:
                prev_lines = f.readlines()
                for line in prev_lines:
                    if len(line) < 2:
                        continue
                    cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
                    for k in metrics:
                        results_dict[k].append(float(cur_res[k]))
                    start_idx += 1
        
        logfile = open(join(output_dir, 'results_vqm.txt'), 'a')

        for idx, seq in enumerate(self.seq_list):
            if resume and idx < start_idx:
                assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'skipping idx {idx} but output not found!'
                continue

            print(f'Evaluating {seq}')
            t0 = time.time()

            seqpath = join(self.db_dir, seq)
            assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'No interpolated frames found for {seq}'

            # load GT, previously interpolated and input frames for each triplet (t advances by 2)
            gt_list, out_list, inputs_list = [], [], []
            num_frames = len([f for f in os.listdir(seqpath) if f.endswith('.png')])
            for t in range(1, num_frames-5, 2):
                im0 = Image.open(join(seqpath, str(t+2).zfill(3)+'.png'))
                im1 = Image.open(join(seqpath, str(t+3).zfill(3)+'.png'))
                im2 = Image.open(join(seqpath, str(t+4).zfill(3)+'.png'))
                # center crop if 4K
                if '4K' in seq:
                    w, h = im0.size
                    im0 = TF.center_crop(im0, (h//2, w//2))
                    im1 = TF.center_crop(im1, (h//2, w//2))
                    im2 = TF.center_crop(im2, (h//2, w//2))
                im0 = self.transform(im0).cuda().unsqueeze(0)
                im1 = self.transform(im1).cuda().unsqueeze(0)
                im2 = self.transform(im2).cuda().unsqueeze(0)

                out = self.transform(Image.open(join(output_dir, seq, 'frame{}.png'.format(t+3)))).cuda().unsqueeze(0)
                gt_list.append(im1)
                out_list.append(out)
                if t == 1:
                    inputs_list.append(im0)
                inputs_list.append(im2)

            # compute sequence-level scores
            for metric in metrics:
                score = getattr(utility, 'calc_{}'.format(metric.lower()))(gt_list, out_list, inputs_list)
                results_dict[metric].append(score)

            # log
            msg = '{:<15s} -- {}'.format(seq, {k: round(results_dict[k][-1], 3) for k in metrics}) + f'    time taken: {round(time.time()-t0,2)}' + '\n'
            print(msg, end='')
            logfile.write(msg)
        
        msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]), 3) for k in metrics}) + '\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()
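When `resume=True`, `eval` rebuilds `results_dict` by parsing each previously logged line back into a dict. The sketch below reproduces that round trip with the standard library only. It assumes the `'{:<15s} -- {dict}    time taken: ...'` line format written by the logging code above. `format_line` and `parse_line` are hypothetical helper names.

```python
import ast

def format_line(name, scores, seconds):
    # Same layout as the log lines above: padded name, ' -- ', the score dict, then timing
    return '{:<15s} -- {}'.format(name, scores) + f'    time taken: {seconds}' + '\n'

def parse_line(line):
    # Keep the text after '-- ', drop everything from 'time' onwards,
    # then evaluate the remaining dict literal safely with ast.literal_eval
    return ast.literal_eval(line.strip().split('-- ')[1].split('time')[0])

line = format_line('seq01', {'PSNR': 31.25, 'FloLPIPS': 0.087}, 1.5)
print(parse_line(line))
```

The closing `Average` line has no timing suffix, but it still parses. The resume loop therefore counts it as one more entry too.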




class Davis90_triplet:
    def __init__(self, db_dir):
        self.seq_list = sorted(os.listdir(db_dir))
        self.db_dir = db_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1, 1]


    def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name=None, resume=False):
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results_vqm.txt')), 'no res file found to resume from!'
            with open(join(output_dir, 'results_vqm.txt'), 'r') as f:
                prev_lines = f.readlines()
                for line in prev_lines:
                    if len(line) < 2:
                        continue
                    cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
                    for k in metrics:
                        results_dict[k].append(float(cur_res[k]))
                    start_idx += 1
        
        logfile = open(join(output_dir, 'results_vqm.txt'), 'a')

        for idx, seq in enumerate(self.seq_list):
            if resume and idx < start_idx:
                assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'skipping idx {idx} but output not found!'
                continue

            print(f'Evaluating {seq}')
            t0 = time.time()

            seqpath = join(self.db_dir, seq)
            assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'No interpolated frames found for {seq}'

            # collect triplets: frames t+2 and t+4 are the inputs, frame t+3 is the ground truth
            gt_list, out_list, inputs_list = [], [], []
            num_frames = len(os.listdir(seqpath))
            for t in range(0, num_frames-6, 2):
                im3 = Image.open(join(seqpath, str(t+2).zfill(5)+'.jpg'))
                im4 = Image.open(join(seqpath, str(t+3).zfill(5)+'.jpg'))
                im5 = Image.open(join(seqpath, str(t+4).zfill(5)+'.jpg'))

                im3 = self.transform(im3).cuda().unsqueeze(0)
                im4 = self.transform(im4).cuda().unsqueeze(0)
                im5 = self.transform(im5).cuda().unsqueeze(0)

                out = self.transform(Image.open(join(output_dir, seq, 'frame{}.png'.format(t+3)))).cuda().unsqueeze(0)
                gt_list.append(im4)
                out_list.append(out)
                if t == 0:
                    inputs_list.append(im3)
                inputs_list.append(im5)

            # compute sequence-level scores
            for metric in metrics:
                score = getattr(utility, 'calc_{}'.format(metric.lower()))(gt_list, out_list, inputs_list)
                results_dict[metric].append(score)

            # log
            msg = '{:<15s} -- {}'.format(seq, {k: round(results_dict[k][-1], 3) for k in metrics}) + f'    time taken: {round(time.time()-t0,2)}' + '\n'
            print(msg, end='')
            logfile.write(msg)
        
        msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]), 3) for k in metrics}) + '\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()
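In the DAVIS loop, each step of `t` yields one triplet. Frames `t+2` and `t+4` are the inputs, and frame `t+3` is both the ground truth and the name of the interpolated output `frame{t+3}.png`. Consecutive triplets share an input frame, so `inputs_list` stores every input exactly once. The sketch below mirrors only the index arithmetic of `range(0, num_frames-6, 2)`. `davis_triplets` and `davis_inputs` are hypothetical names.

```python
def davis_triplets(num_frames):
    # (left input, ground truth, right input) frame indices, as in the loop above
    return [(t + 2, t + 3, t + 4) for t in range(0, num_frames - 6, 2)]

def davis_inputs(num_frames):
    trips = davis_triplets(num_frames)
    # The first triplet contributes both inputs; later ones add only their right input,
    # because their left input equals the previous triplet's right input
    return ([trips[0][0]] + [right for _, _, right in trips]) if trips else []

print(davis_triplets(10))  # frame files would be str(i).zfill(5) + '.jpg'
print(davis_inputs(10))
```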



class Ucf101_triplet:
    def __init__(self, db_dir):
        self.db_dir = db_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # output tensor in [-1, 1]

        self.im_list = os.listdir(db_dir)

        self.input3_list = []
        self.input5_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input3_list.append(self.transform(Image.open(join(db_dir, item, 'frame1.png'))).cuda().unsqueeze(0))
            self.input5_list.append(self.transform(Image.open(join(db_dir, item, 'frame2.png'))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, item, 'framet.png'))).cuda().unsqueeze(0))

    def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name='output.png', resume=False):
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results_vqm.txt')), 'no res file found to resume from!'
            with open(join(output_dir, 'results_vqm.txt'), 'r') as f:
                prev_lines = f.readlines()
                for line in prev_lines:
                    if len(line) < 2:
                        continue
                    cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string
                    for k in metrics:
                        results_dict[k].append(float(cur_res[k]))
                    start_idx += 1
        
        logfile = open(join(output_dir, 'results_vqm.txt'), 'a')

        for idx in range(len(self.im_list)):
            if resume and idx < start_idx:
                assert os.path.exists(join(output_dir, self.im_list[idx], output_name)), f'skipping idx {idx} but output not found!'
                continue

            print(f'Evaluating {self.im_list[idx]}')
            t0 = time.time()

            assert os.path.exists(join(output_dir, self.im_list[idx], output_name)), f'No interpolated frames found for {self.im_list[idx]}'

            out = self.transform(Image.open(join(output_dir, self.im_list[idx], output_name))).cuda().unsqueeze(0)
            gt = self.gt_list[idx]

            for metric in metrics:
                score = getattr(utility, 'calc_{}'.format(metric.lower()))([gt], [out], [self.input3_list[idx], self.input5_list[idx]])
                results_dict[metric].append(score)


            msg = '{:<15s} -- {}'.format(self.im_list[idx], {k: round(results_dict[k][-1],3) for k in metrics}) + f'    time taken: {round(time.time()-t0,2)}' + '\n'
            print(msg, end='')
            logfile.write(msg)

        msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]),3) for k in metrics}) + '\n\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()

================================================
FILE: ldm/data/vfitransforms.py
================================================
import random
import torch
import torchvision
import torchvision.transforms.functional as TF


def rand_crop(*args, sz):
    i, j, h, w = torchvision.transforms.RandomCrop.get_params(args[0], output_size=sz)
    out = []
    for im in args:
        out.append(TF.crop(im, i, j, h, w))
    return out


def rand_flip(*args, p):
    out = list(args)
    if random.random() < p:
        for i, im in enumerate(out):
            out[i] = TF.hflip(im)
    if random.random() < p:
        for i, im in enumerate(out):
            out[i] = TF.vflip(im)
    return out


def rand_reverse(*args, p):
    if random.random() < p:
        return args[::-1]
    else:
        return args
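The helpers above draw their random parameters once and apply them to every frame passed in. A (prev, gt, next) triplet therefore stays spatially and temporally aligned after augmentation. `rand_crop`, for example, calls `RandomCrop.get_params` only on the first frame. The dependency-free sketch below shows the same idea on toy 2-D frames stored as nested lists rather than PIL images or tensors. The function names are hypothetical, and only horizontal flips are shown.

```python
import random

def crop(im, i, j, h, w):
    # im is a list of rows; keep rows i..i+h-1 and columns j..j+w-1
    return [row[j:j + w] for row in im[i:i + h]]

def rand_crop_frames(*ims, size, rng=random):
    h, w = size
    H, W = len(ims[0]), len(ims[0][0])
    # draw one crop window from the first frame and reuse it for every frame
    i = rng.randint(0, H - h)
    j = rng.randint(0, W - w)
    return [crop(im, i, j, h, w) for im in ims]

def rand_hflip_frames(*ims, p, rng=random):
    # one coin toss decides for all frames, so they stay aligned
    if rng.random() < p:
        return [[row[::-1] for row in im] for im in ims]
    return list(ims)

def rand_reverse_frames(*ims, p, rng=random):
    # reversing (prev, gt, next) gives (next, gt, prev); the middle frame stays put
    return ims[::-1] if rng.random() < p else ims

def make_frame(k, n=4):
    # toy n x n frame whose values are offset by 100 per frame index k
    return [[100 * k + r * n + c for c in range(n)] for r in range(n)]

rng = random.Random(0)
frames = [make_frame(k) for k in range(3)]
crops = rand_crop_frames(*frames, size=(2, 2), rng=rng)
```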

================================================
FILE: ldm/lr_scheduler.py
================================================
import numpy as np


class LambdaWarmUpCosineScheduler:
    """
    note: use with a base_lr of 1.0
    """
    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        else:
            t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
            t = min(t, 1.0)
            lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
                    1 + np.cos(t * np.pi))
            self.last_lr = lr
            return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n,**kwargs)
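`LambdaWarmUpCosineScheduler.schedule` ramps linearly from `lr_start` to `lr_max` over `warm_up_steps`. It then follows half a cosine down to `lr_min` and holds that value after `max_decay_steps`. The method returns the learning rate itself, and `LambdaLR` multiplies the optimizer's base lr by the lambda's output. That is why the docstring asks for a base lr of 1.0. The sketch below is a standalone function with the same arithmetic (`warmup_cosine` is a hypothetical name), evaluated on an illustrative schedule.

```python
import math

def warmup_cosine(n, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps):
    if n < warm_up_steps:
        # linear warm-up from lr_start towards lr_max
        return (lr_max - lr_start) / warm_up_steps * n + lr_start
    # cosine decay from lr_max to lr_min, clamped once max_decay_steps is reached
    t = min((n - warm_up_steps) / (max_decay_steps - warm_up_steps), 1.0)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(t * math.pi))

for step in (0, 5, 10, 60, 110, 500):
    print(step, round(warmup_cosine(step, 10, 0.1, 1.0, 0.0, 110), 4))
```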


class LambdaWarmUpCosineScheduler2:
    """
    supports repeated iterations, configurable via lists
    note: use with a base_lr of 1.0.
    """
    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        interval = 0
        for cl in self.cum_cycles[1:]:
            if n <= cl:
                return interval
            interval += 1

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                                                       f"current cycle {cycle}")
        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
            t = min(t, 1.0)
            f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
                    1 + np.cos(t * np.pi))
            self.last_f = f
            return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)
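`LambdaWarmUpCosineScheduler2` supports several cycles. `find_in_interval` returns the first cycle whose cumulative end `cl` satisfies `n <= cl`, so a step exactly on a boundary still belongs to the earlier cycle. `schedule` then rebases `n` to the start of that cycle. The sketch below performs the same lookup with `bisect_left` over the cumulative boundaries. `locate_cycle` is a hypothetical name. One deliberate difference: it raises `ValueError` past the last cycle, whereas the original method returns `None` there.

```python
from bisect import bisect_left
from itertools import accumulate

def locate_cycle(n, cycle_lengths):
    # cumulative cycle ends, e.g. [100, 300] for cycle_lengths [100, 200]
    ends = list(accumulate(cycle_lengths))
    # first end >= n, i.e. the same rule as `if n <= cl` in find_in_interval
    cycle = bisect_left(ends, n)
    if cycle == len(ends):
        raise ValueError('step is past the last configured cycle')
    start = ends[cycle - 1] if cycle > 0 else 0
    return cycle, n - start  # cycle index and step within that cycle

print(locate_cycle(100, [100, 200]), locate_cycle(101, [100, 200]))
```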


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):

    def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                                                       f"current cycle {cycle}")

        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        else:
            f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
            self.last_f = f
            return f
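After warm-up, `LambdaLinearScheduler` decays linearly as `(cycle_length - n) / cycle_length`. The denominator is the whole cycle length, not the length that remains after warm-up. Right after warm-up the multiplier is therefore `f_min + (f_max - f_min) * (L - w) / L`, slightly below `f_max`. The single-cycle sketch below reproduces both branches. `linear_multiplier` is a hypothetical name, and the numbers are illustrative only.

```python
import math

def linear_multiplier(n, warm_up_steps, f_start, f_max, f_min, cycle_length):
    if n < warm_up_steps:
        # same linear warm-up as in LambdaWarmUpCosineScheduler2
        return (f_max - f_start) / warm_up_steps * n + f_start
    # linear decay measured over the full cycle length
    return f_min + (f_max - f_min) * (cycle_length - n) / cycle_length

for step in (5, 10, 55, 100):
    print(step, linear_multiplier(step, 10, 0.0, 1.0, 0.0, 100))
```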



================================================
FILE: ldm/models/autoencoder.py
================================================
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
from packaging import version
from ldm.modules.ema import LitEma
from contextlib import contextmanager

from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

from ldm.modules.diffusionmodules.model import *

from ldm.util import instantiate_from_config



class VQFlowNet(pl.LightningModule):
    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False, # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        super().__init__()
        self.embed_dim = embed_dim # 3
        self.n_embed = n_embed # 8192
        self.image_key = image_key # 'image'
        self.encoder = FlowEncoder(**ddconfig)
        self.decoder = FlowDecoderWithResidual(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels)==int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor
        self.h0 = None
        self.w0 = None
        self.h_padded = None
        self.w_padded = None
        self.pad_h = 0
        self.pad_w = 0

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x, ret_feature=False):
        '''
        Set ret_feature = True when encoding conditions in ddpm
        '''
        # Pad the input first so its height and width are divisible by min_side.
        # This tolerates different f values and input sizes, and satisfies
        # the downsampling operations in the DDPM UNet.
        self.h0, self.w0 = x.shape[2:]
        # 8: window size for max vit
        # 2**(nr-1): f 
        # 4: factor of downsampling in DDPM unet
        min_side = 8 * 2**(self.encoder.num_resolutions-1) * 4
        if self.h0 % min_side != 0:
            pad_h = min_side - (self.h0 % min_side)
            if pad_h == self.h0: # this is to avoid padding 256 patches
                pad_h = 0
            x = F.pad(x, (0, 0, 0, pad_h), mode='reflect')
            self.h_padded = True
            self.pad_h = pad_h

        if self.w0 % min_side != 0:
            pad_w = min_side - (self.w0 % min_side)
            if pad_w == self.w0:
                pad_w = 0
            x = F.pad(x, (0, pad_w, 0, 0), mode='reflect')
            self.w_padded = True
            self.pad_w = pad_w

        phi_list = None
        if ret_feature:
            h, phi_list = self.encoder(x, ret_feature)
        else:
            h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        if ret_feature:
            return quant, emb_loss, info, phi_list
        return quant, emb_loss, info
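`encode` pads each spatial side up to the next multiple of `min_side = 8 * 2**(num_resolutions-1) * 4`, with one exception. When the required padding equals the side length itself, it pads nothing. That happens exactly when the side is `min_side / 2`, so with `min_side = 512` it covers the 256-pixel patches mentioned in the comment. The sketch below reproduces only the arithmetic. `pad_amounts` is a hypothetical helper, and `num_resolutions=5` is an illustrative assumption; the real value comes from `ddconfig`.

```python
def pad_amounts(h0, w0, num_resolutions):
    # 8: MaxViT window, 2**(nr-1): encoder factor f, 4: DDPM UNet downsampling
    min_side = 8 * 2 ** (num_resolutions - 1) * 4

    def amount(size):
        if size % min_side == 0:
            return 0
        pad = min_side - size % min_side
        return 0 if pad == size else pad  # same special case as in encode()

    return min_side, amount(h0), amount(w0)

print(pad_amounts(256, 448, 5))    # training-sized patch
print(pad_amounts(1080, 1920, 5))  # 1080p frame
```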

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant, x_prev, x_next):
        cond_dict = dict(
            phi_prev_list = self.encode(x_prev, ret_feature=True)[-1],
            phi_next_list = self.encode(x_next, ret_feature=True)[-1],
            frame_prev = F.pad(x_prev, (0, self.pad_w, 0, self.pad_h), mode='reflect'),
            frame_next = F.pad(x_next, (0, self.pad_w, 0, self.pad_h), mode='reflect')
        )
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant, cond_dict)

        # check if image is padded and return the original part only
        if self.h_padded:
            dec = dec[:, :, 0:self.h0, :]
        if self.w_padded:
            dec = dec[:, :, :, 0:self.w0]
        return dec

    def decode_code(self, code_b, x_prev, x_next):
        # decode() needs the neighbouring frames as conditioning
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b, x_prev, x_next)
        return dec

    def forward(self, input, x_prev, x_next, return_pred_indices=False):

        quant, diff, (_,_,ind) = self.encode(input)
        dec = self.decode(quant, x_prev, x_next)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode="bicubic")
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        x_prev = self.get_input(batch, 'prev_frame')
        x_next = self.get_input(batch, 'next_frame')
        xrec, qloss = self(x, x_prev, x_next)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")

            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        x = self.get_input(batch, self.image_key)
        x_prev = self.get_input(batch, 'prev_frame')
        x_next = self.get_input(batch, 'next_frame')
        xrec, qloss = self(x, x_prev, x_next)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val"+suffix,
                                        )

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val"+suffix,
                                            )
        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                   prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                   prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return {**log_dict_ae, **log_dict_disc}

    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor*self.learning_rate
        print("lr_d", lr_d)
        print("lr_g", lr_g)
        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
                                  list(self.decoder.parameters())+
                                  list(self.quantize.parameters())+
                                  list(self.quant_conv.parameters())+
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr_g, betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x_prev = self.get_input(batch, 'prev_frame')
        x_next = self.get_input(batch, 'next_frame')
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x, x_prev, x_next)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x, x_prev, x_next)
                if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x
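`to_rgb` visualises inputs with more than three channels in two steps. A random 1x1 convolution first projects them to 3 channels. A global min-max rescale then maps the result to [-1, 1]. The sketch below applies the same two steps to per-pixel channel lists. It uses a fixed weight matrix in place of the random `colorize` buffer so the output is reproducible. `project_to_rgb` is a hypothetical function, not part of the module.

```python
def project_to_rgb(pixels, weight):
    # pixels: list of C-channel vectors; weight: 3 x C matrix, like a 1x1 conv kernel
    projected = [[sum(w * p for w, p in zip(row, px)) for row in weight] for px in pixels]
    flat = [v for px in projected for v in px]
    lo, hi = min(flat), max(flat)  # global min/max, as x.min() / x.max() in to_rgb
    return [[2.0 * (v - lo) / (hi - lo) - 1.0 for v in px] for px in projected]

weight = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]]
print(project_to_rgb([[1, 0, 0, 0], [0, 0, 0, 2]], weight))
```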

    
class VQFlowNetInterface(VQFlowNet):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def encode(self, x, ret_feature=False):
        '''
        Set ret_feature = True when encoding conditions in ddpm
        '''
        # Pad the input first so its height and width are divisible by min_side.
        # This tolerates different f values and input sizes, and satisfies
        # the downsampling operations in the DDPM UNet.
        self.h0, self.w0 = x.shape[2:]
        # min_side is hard-coded to 512 here (formerly 8 * 2**(self.encoder.num_resolutions-1) * 16,
        # with 8 the MaxViT window size and 2**(nr-1) the encoder downsampling factor f)
        min_side = 512
        min_side = min_side // 2 if self.h0 <= 256 else min_side
        if self.h0 % min_side != 0:
            pad_h = min_side - (self.h0 % min_side)
            if pad_h == self.h0: # this is to avoid padding 256 patches
                pad_h = 0
            x = F.pad(x, (0, 0, 0, pad_h), mode='reflect')
            self.h_padded = True
            self.pad_h = pad_h

        if self.w0 % min_side != 0:
            pad_w = min_side - (self.w0 % min_side)
            if pad_w == self.w0:
                pad_w = 0
            x = F.pad(x, (0, pad_w, 0, 0), mode='reflect')
            self.w_padded = True
            self.pad_w = pad_w

        phi_list = None
        if ret_feature:
            h, phi_list = self.encoder(x, ret_feature)
        else:
            h = self.encoder(x)
        h = self.quant_conv(h)
        if ret_feature:
            return h, phi_list
        return h

    def decode(self, h, x_prev, x_next, phi_prev_list, phi_next_list, force_not_quantize=False):
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        cond_dict = dict(
            phi_prev_list = phi_prev_list,
            phi_next_list = phi_next_list,
            frame_prev = F.pad(x_prev, (0, self.pad_w, 0, self.pad_h), mode='reflect'),
            frame_next = F.pad(x_next, (0, self.pad_w, 0, self.pad_h), mode='reflect')
        )
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant, cond_dict)

        # check if image is padded and return the original part only
        if self.h_padded:
            dec = dec[:, :, 0:self.h0, :]
        if self.w_padded:
            dec = dec[:, :, :, 0:self.w0]
        return dec
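`decode` undoes the encoder padding by cropping the output back to `self.h0` x `self.w0`. The conditioning frames are padded with the same reflect mode and amounts, so they line up with the padded latent. The list-based sketch below shows that pad-then-crop round trip. It applies a 1-D reflection to each row (right edge) and then to the list of rows (bottom edge), mimicking `F.pad(..., mode='reflect')`, which does not repeat the edge sample and needs each pad to be smaller than the padded dimension. The helper names are hypothetical.

```python
def reflect_pad_end(seq, pad):
    # mirror the tail without repeating the last element, as mode='reflect' does
    assert 0 <= pad < len(seq)
    return seq + seq[-2:-2 - pad:-1]

def pad_then_crop(image, pad_h, pad_w):
    h0, w0 = len(image), len(image[0])                       # recorded like self.h0, self.w0
    padded = [reflect_pad_end(row, pad_w) for row in image]  # pad on the right
    padded = reflect_pad_end(padded, pad_h)                  # pad at the bottom
    # ... the decoder would run on the padded frame here ...
    return padded, [row[:w0] for row in padded[:h0]]         # like dec[:, :, :h0, :w0]

image = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
padded, cropped = pad_then_crop(image, 2, 3)
```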

================================================
FILE: ldm/models/diffusion/__init__.py
================================================


================================================
FILE: ldm/models/diffusion/ddim.py
================================================
"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm
from functools import partial

from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like


class DDIMSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta,verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
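`make_schedule` picks a DDIM sub-sequence of the DDPM timesteps and computes a noise scale for each selected step. That scale has the same form as `sigmas_for_original_sampling_steps` above: `eta * sqrt((1 - a_prev) / (1 - a) * (1 - a / a_prev))`. With `ddim_eta=0` every sigma is zero and sampling is deterministic. The pure-Python sketch below derives these values. It assumes that `make_ddim_timesteps` with `'uniform'` discretisation takes every `c`-th step shifted by one, and that `make_ddim_sampling_parameters` pairs each selected alpha with the previously selected one, starting from `alphas_cumprod[0]`. Both helpers live in `ldm/modules/diffusionmodules/util.py`, which is not shown here. `ddim_schedule` and the toy schedule are made up for illustration.

```python
import math

def ddim_schedule(alphas_cumprod, num_ddim_steps, eta):
    T = len(alphas_cumprod)
    c = T // num_ddim_steps
    # assumed 'uniform' discretisation: every c-th DDPM step, shifted by one
    timesteps = [t + 1 for t in range(0, T, c)]
    alphas = [alphas_cumprod[t] for t in timesteps]
    # the first "previous" alpha comes from DDPM step 0
    alphas_prev = [alphas_cumprod[0]] + alphas[:-1]
    sigmas = [eta * math.sqrt((1 - ap) / (1 - a) * (1 - a / ap))
              for a, ap in zip(alphas, alphas_prev)]
    return timesteps, alphas, alphas_prev, sigmas

# toy schedule with a constant beta of 0.1, so alpha_bar_t = 0.9 ** (t + 1)
alphas_cumprod = [0.9 ** (i + 1) for i in range(10)]
timesteps, alphas, alphas_prev, sigmas = ddim_schedule(alphas_cumprod, 5, eta=1.0)
```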

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # unconditional_conditioning must come in the same format as the conditioning, e.g. as encoded tokens
               **kwargs
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        if verbose:
            print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    verbose=verbose
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,verbose=True):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        if verbose:
            print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) if verbose else time_range

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
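

# --- Illustrative sketch (not part of the original LDMVFI code) ---
# A standalone, hypothetical restatement of the math in p_sample_ddim above, for
# scalar cumulative alphas a_t (current step) and a_prev (previous step):
#   e_t     = e_uncond + s * (e_cond - e_uncond)          (classifier-free guidance)
#   pred_x0 = (x_t - sqrt(1 - a_t) * e_t) / sqrt(a_t)
#   sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
#   x_prev  = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t**2) * e_t + sigma_t * z
# eta = 0 gives the deterministic DDIM sampler. The sigma_t form matches the one used
# for ddim_sigmas_for_original_num_steps in make_schedule. Quantisation, temperature
# and noise dropout are deliberately left out; the helper names are made up here.
import math

import torch


def guided_eps_sketch(e_uncond, e_cond, scale):
    """Classifier-free guidance: scale=1 returns the conditional prediction."""
    return e_uncond + scale * (e_cond - e_uncond)


def ddim_step_sketch(x_t, e_t, a_t, a_prev, eta=0.0, generator=None):
    """One DDIM update; returns (x_prev, pred_x0) in the same order as p_sample_ddim."""
    sigma_t = eta * math.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
    pred_x0 = (x_t - math.sqrt(1 - a_t) * e_t) / math.sqrt(a_t)
    # direction pointing to x_t; max() guards against tiny negative rounding error
    dir_xt = math.sqrt(max(0.0, 1 - a_prev - sigma_t ** 2)) * e_t
    if sigma_t > 0:
        noise = sigma_t * torch.randn(x_t.shape, generator=generator, dtype=x_t.dtype)
    else:
        noise = torch.zeros_like(x_t)
    x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt + noise
    return x_prev, pred_x0

# Usage: if e_t equals the exact noise used to build x_t from x_0, pred_x0 recovers
# x_0, and with eta=0 the step lands on sqrt(a_prev) * x_0 + sqrt(1 - a_prev) * e_t.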


================================================
FILE: ldm/models/diffusion/ddpm.py
================================================
"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
https://github.com/CompVis/taming-transformers
-- merci
"""

import torch
import torch.nn as nn
import numpy as np
import pytorch_lightning as pl
from torch.optim.lr_scheduler import LambdaLR
from einops import rearrange, repeat
from contextlib import contextmanager
from functools import partial
from tqdm import tqdm
from torchvision.utils import make_grid
from pytorch_lightning.utilities.distributed import rank_zero_only

from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
from ldm.modules.ema import LitEma
from ldm.models.autoencoder import * 
from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
from ldm.models.diffusion.ddim import DDIMSampler


__conditioning_keys__ = {'concat': 'c_concat',
                         'crossattn': 'c_crossattn',
                         'adm': 'y'}
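

# --- Illustrative sketch (not part of the original LDMVFI code) ---
# DDPM.ema_scope (below) temporarily swaps exponential-moving-average (EMA) weights
# into self.model and restores the training weights on exit, even if the block raises.
# The helpers here are a hypothetical, minimal restatement of that idea using a plain
# dict of shadow tensors; LitEma's own update rule (e.g. any decay warm-up) may differ.
from contextlib import contextmanager

import torch
import torch.nn as nn


@torch.no_grad()
def ema_update_sketch(shadow, model, decay=0.999):
    """shadow <- decay * shadow + (1 - decay) * current parameters."""
    for name, p in model.named_parameters():
        shadow[name].mul_(decay).add_(p.detach(), alpha=1.0 - decay)


@contextmanager
def ema_swap_sketch(model, shadow):
    """Load the EMA weights for the duration of the with-block, then restore."""
    backup = {n: p.detach().clone() for n, p in model.named_parameters()}
    with torch.no_grad():
        for n, p in model.named_parameters():
            p.copy_(shadow[n])
    try:
        yield model
    finally:
        with torch.no_grad():
            for n, p in model.named_parameters():
                p.copy_(backup[n])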


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self
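

# --- Illustrative sketch (not part of the original LDMVFI code) ---
# The forward process used by DDPM.q_sample below draws
#   x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * eps,
# and DDPM.predict_start_from_noise inverts it when eps is known:
#   x_0 = sqrt(1 / alphas_cumprod[t]) * x_t - sqrt(1 / alphas_cumprod[t] - 1) * eps.
# These hypothetical helpers restate both with a per-sample gather that is assumed to
# behave like extract_into_tensor (imported from ldm.modules.diffusionmodules.util).
import torch


def extract_sketch(a, t, x_shape):
    """Pick a[t[i]] for each batch element and reshape to (B, 1, 1, ...) for broadcasting."""
    out = a.gather(-1, t)
    return out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))


def q_sample_sketch(x_start, t, noise, alphas_cumprod):
    return (extract_sketch(alphas_cumprod.sqrt(), t, x_start.shape) * x_start
            + extract_sketch((1.0 - alphas_cumprod).sqrt(), t, x_start.shape) * noise)


def predict_start_from_noise_sketch(x_t, t, noise, alphas_cumprod):
    return (extract_sketch((1.0 / alphas_cumprod).sqrt(), t, x_t.shape) * x_t
            - extract_sketch((1.0 / alphas_cumprod - 1.0).sqrt(), t, x_t.shape) * noise)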



class DDPM(pl.LightningModule):
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(self,
                 unet_config,
                 timesteps=1000,
                 beta_schedule="linear",
                 loss_type="l2",
                 ckpt_path=None,
                 ignore_keys=[],
                 load_only_unet=False,
                 monitor="val/loss",
                 use_ema=True,
                 first_stage_key="image",
                 image_size=256,
                 channels=3,
                 log_every_t=100,
                 clip_denoised=True,
                 linear_start=1e-4,
                 linear_end=2e-2,
                 cosine_s=8e-3,
                 given_betas=None,
                 original_elbo_weight=0.,
                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
                 l_simple_weight=1.,
                 conditioning_key=None,
                 parameterization="eps",  # all assuming fixed variance schedules
                 scheduler_config=None,
                 use_positional_encodings=False,
                 learn_logvar=False,
                 logvar_init=0.,
                 ):
        super().__init__()
        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.first_stage_key = first_stage_key
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.use_positional_encodings = use_positional_encodings
        self.model = DiffusionWrapper(unet_config, conditioning_key)
        count_params(self.model, verbose=True)
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        self.use_scheduler = scheduler_config is not None
        if self.use_scheduler:
            self.scheduler_config = scheduler_config

        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight

        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)

        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)

        self.loss_type = loss_type

        self.learn_logvar = learn_logvar
        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
        if self.learn_logvar:
            self.logvar = nn.Parameter(self.logvar, requires_grad=True)


    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                       cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                    1. - alphas_cumprod) + self.v_posterior * betas
        # above: for v_posterior = 0, equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                        2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            lvlb_weights = 0.5 * torch.sqrt(torch.Tensor(alphas_cumprod)) / (2. * (1 - torch.Tensor(alphas_cumprod)))
        else:
            raise NotImplementedError("mu not supported")
        # TODO how to choose this term
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location="cpu")
        if "state_dict" in list(sd.keys()):
            sd = sd["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
            sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
            print(f"Unexpected Keys: {unexpected}")

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    def predict_start_from_noise(self, x_t, t, noise):
        return (
                extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
                extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        posterior_mean = (
                extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
                extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, x, t, clip_denoised: bool):
        model_out = self.model(x, t)
        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def p_sample_loop(self, shape, return_intermediates=False):
        device = self.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device)
        intermediates = [img]
        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                                clip_denoised=self.clip_denoised)
            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
                intermediates.append(img)
        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(self, batch_size=16, return_intermediates=False):
        image_size = self.image_size
        channels = self.channels
        return self.p_sample_loop((batch_size, channels, image_size, image_size),
                                  return_intermediates=return_intermediates)

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    def get_loss(self, pred, target, mean=True):
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            if mean:
                loss = torch.nn.functional.mse_loss(target, pred)
            else:
                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
        else:
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")

        return loss

    def p_losses(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t)

        loss_dict = {}
        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

        log_prefix = 'train' if self.training else 'val'

        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
        loss_simple = loss.mean() * self.l_simple_weight

        loss_vlb = (self.lvlb_weights[t] * loss).mean()
        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

        loss = loss_simple + self.original_elbo_weight * loss_vlb

        loss_dict.update({f'{log_prefix}/loss': loss})

        return loss, loss_dict

    def forward(self, x, *args, **kwargs):
        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        return self.p_losses(x, t, *args, **kwargs)

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = rearrange(x, 'b h w c -> b c h w')
        x = x.to(memory_format=torch.contiguous_format).float()
        return x

    def shared_step(self, batch):
        x = self.get_input(batch, self.first_stage_key)
        loss, loss_dict = self(x)
        return loss, loss_dict

    def training_step(self, batch, batch_idx):
        loss, loss_dict = self.shared_step(batch)

        self.log_dict(loss_dict, prog_bar=True,
                      logger=True, on_step=True, on_epoch=True)

        self.log("global_step", float(self.global_step),
                 prog_bar=True, logger=True, on_step=True, on_epoch=False)

        if self.use_scheduler:
            lr = self.optimizers().param_groups[0]['lr']
            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        _, loss_dict_no_ema = self.shared_step(batch)
        with self.ema_scope():
            _, loss_dict_ema = self.shared_step(batch)
            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self.model)

    def _get_rows_from_list(self, samples):
        n_imgs_per_row = len(samples)
        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
        log = dict()
        x = self.get_input(batch, self.first_stage_key)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        x = x.to(self.device)[:N]
        log["inputs"] = x

        # get diffusion row
        diffusion_row = list()
        x_start = x[:n_row]

        for t in range(self.num_timesteps):
            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                t = t.to(self.device).long()
                noise = torch.randn_like(x_start)
                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
                diffusion_row.append(x_noisy)

        log["diffusion_row"] = self._get_rows_from_list(diffusion_row)

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)

            log["samples"] = samples
            log["denoise_row"] = self._get_rows_from_list(denoise_row)

        if return_keys:
            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
                return log
            else:
                return {key: log[key] for key in return_keys}
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.learn_logvar:
            params = params + [self.logvar]
        opt = torch.optim.AdamW(params, lr=lr)
        return opt
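

# --- Illustrative sketch (not part of the original LDMVFI code) ---
# A NumPy restatement of the quantities built in DDPM.register_schedule above. It
# assumes the "linear" schedule is betas = linspace(sqrt(linear_start), sqrt(linear_end), T)**2;
# make_beta_schedule itself lives in ldm.modules.diffusionmodules.util and is not shown
# here, so treat that line as an assumption. The function name is hypothetical.
import numpy as np


def schedule_sketch(timesteps=1000, linear_start=1e-4, linear_end=2e-2, v_posterior=0.0):
    betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=np.float64) ** 2
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
    # Variance of q(x_{t-1} | x_t, x_0), optionally interpolated towards beta_t.
    posterior_variance = ((1 - v_posterior) * betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
                          + v_posterior * betas)
    # Mean of q(x_{t-1} | x_t, x_0) is coef1 * x_0 + coef2 * x_t.
    coef1 = betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    coef2 = (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
    return dict(betas=betas, alphas=alphas, alphas_cumprod=alphas_cumprod,
                alphas_cumprod_prev=alphas_cumprod_prev,
                posterior_variance=posterior_variance, coef1=coef1, coef2=coef2)

# Design note: posterior_variance[0] is exactly 0 because alphas_cumprod_prev[0] = 1,
# which is why register_schedule clips the log variance with np.maximum(..., 1e-20).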


class LatentDiffusion(DDPM):
    """main class"""
    def __init__(self,
                 first_stage_config,
                 cond_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 *args, **kwargs):
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        if cond_stage_config == '__is_unconditional__':
            conditioning_key = None
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except Exception:
            self.num_downs = 0
        self.scale_factor = scale_factor

        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False

        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True

    def make_cond_schedule(self):
        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
        self.cond_ids[:self.num_timesteps_cond] = ids


    def register_schedule(self,
                          given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)

        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()

    def instantiate_first_stage(self, config):
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

    def instantiate_cond_stage(self, config):
        if not self.cond_stage_trainable:
            if config == "__is_first_stage__":
                print("Using first stage also as cond stage.")
                self.cond_stage_model = self.first_stage_model
            elif config == "__is_unconditional__":
                print(f"Training {self.__class__.__name__} as an unconditional model.")
                self.cond_stage_model = None
                # self.be_unconditional = True
            else:
                model = instantiate_from_config(config)
                self.cond_stage_model = model.eval()
                self.cond_stage_model.train = disabled_train
                for param in self.cond_stage_model.parameters():
                    param.requires_grad = False
        else:
            assert config != '__is_first_stage__'
            assert config != '__is_unconditional__'
            model = instantiate_from_config(config)
            self.cond_stage_model = model

    def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
        denoise_row = []
        for zd in tqdm(samples, desc=desc):
            denoise_row.append(self.decode_first_stage(zd.to(self.device),
                                                            force_not_quantize=force_no_decoder_quantization))
        n_imgs_per_row = len(denoise_row)
        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return self.scale_factor * z

    def get_learned_conditioning(self, c):
        if self.cond_stage_forward is None:
            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
                c = self.cond_stage_model.encode(c)
            else:
                c = self.cond_stage_model(c)
        else:
            assert hasattr(self.cond_stage_model, self.cond_stage_forward)
            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
        return c


    @torch.no_grad()
    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
                  cond_key=None, return_original_cond=False, bs=None):
        x = super().get_input(batch, k)
        if bs is not None:
            x = x[:bs]
        x = x.to(self.device)
        encoder_posterior = self.encode_first_stage(x)
        z = self.get_first_stage_encoding(encoder_posterior).detach()

        if self.model.conditioning_key is not None:
            if cond_key is None:
                cond_key = self.cond_stage_key
            if cond_key != self.first_stage_key:
                if cond_key in ['caption', 'coordinates_bbox']:
                    xc = batch[cond_key]
                elif cond_key == 'class_label':
                    xc = batch
                else:
                    xc = super().get_input(batch, cond_key).to(self.device)
            else:
                xc = x
            if not self.cond_stage_trainable or force_c_encode:
                if isinstance(xc, dict) or isinstance(xc, list):
                    c = self.get_learned_conditioning(xc)
                else:
                    c = self.get_learned_conditioning(xc.to(self.device))
            else:
                c = xc
            if bs is not None:
                c = c[:bs]

            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                ckey = __conditioning_keys__[self.model.conditioning_key]
                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}

        else:
            c = None
            xc = None
            if self.use_positional_encodings:
                pos_x, pos_y = self.compute_latent_shifts(batch)
                c = {'pos_x': pos_x, 'pos_y': pos_y}
        out = [z, c]
        if return_first_stage_outputs:
            xrec = self.decode_first_stage(z)
            out.extend([x, xrec])
        if return_original_cond:
            out.append(xc)
        return out

    @torch.no_grad()
    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()

        z = 1. / self.scale_factor * z

        return self.first_stage_model.decode(z)


    @torch.no_grad()
    def encode_first_stage(self, x):
        return self.first_stage_model.encode(x)

    def shared_step(self, batch, **kwargs):
        x, c = self.get_input(batch, self.first_stage_key)
        loss = self(x, c)
        return loss

    def forward(self, x, c, *args, **kwargs):
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        if self.model.conditioning_key is not None:
            assert c is not None
            if self.cond_stage_trainable:
                c = self.get_learned_conditioning(c)
            if self.shorten_cond_schedule:  # TODO: drop this option
                tc = self.cond_ids[t].to(self.device)
                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
        return self.p_losses(x, c, t, *args, **kwargs)


    def apply_model(self, x_noisy, t, cond, return_ids=False):

        if isinstance(cond, dict):
            # hybrid case: cond is expected to be a dict
            pass
        else:
            if not isinstance(cond, list):
                cond = [cond]
            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
            cond = {key: cond}

        x_recon = self.model(x_noisy, t, **cond)

        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        else:
            return x_recon


    def p_losses(self, x_start, cond, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_output = self.apply_model(x_noisy, t, cond)

        loss_dict = {}
        prefix = 'train' if self.training else 'val'

        if self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "eps":
            target = noise
        else:
            raise NotImplementedError()

        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

        logvar_t = self.logvar[t].to(self.device)
        loss = loss_simple / torch.exp(logvar_t) + logvar_t
        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
        if self.learn_logvar:
            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
            loss_dict.update({'logvar': self.logvar.data.mean()})

        loss = self.l_simple_weight * loss.mean()

        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
        loss += (self.original_elbo_weight * loss_vlb)
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict

    def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
                        return_x0=False, score_corrector=None, corrector_kwargs=None):
        t_in = t
        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)

        if score_corrector is not None:
            assert self.parameterization == "eps"
            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)

        if return_codebook_ids:
            model_out, logits = model_out

        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        else:
            raise NotImplementedError()

        if clip_denoised:
            x_recon.clamp_(-1., 1.)
        if quantize_denoised:
            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        if return_codebook_ids:
            return model_mean, posterior_variance, posterior_log_variance, logits
        elif return_x0:
            return model_mean, posterior_variance, posterior_log_variance, x_recon
        else:
            return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,
                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
        b, *_, device = *x.shape, x.device
        outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
                                       return_codebook_ids=return_codebook_ids,
                                       quantize_denoised=quantize_denoised,
                                       return_x0=return_x0,
                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
        if return_codebook_ids:
            raise DeprecationWarning("Support dropped.")
            model_mean, _, model_log_variance, logits = outputs
        elif return_x0:
            model_mean, _, model_log_variance, x0 = outputs
        else:
            model_mean, _, model_log_variance = outputs

        noise = noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

        # if return_codebook_ids:
        #     return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
        if return_x0:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
        else:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
                              log_every_t=None):
        if not log_every_t:
            log_every_t = self.log_every_t
        timesteps = self.num_timesteps
        if batch_size is not None:
            b = batch_size
            shape = [batch_size] + list(shape)
        else:
            b = batch_size = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=self.device)
        else:
            img = x_T
        intermediates = []
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
                        total=timesteps) if verbose else reversed(
            range(0, timesteps))
        if isinstance(temperature, float):
            temperature = [temperature] * timesteps

        for i in iterator:
            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img, x0_partial = self.p_sample(img, cond, ts,
                                            clip_denoised=self.clip_denoised,
                                            quantize_denoised=quantize_denoised, return_x0=True,
                                            temperature=temperature[i], noise_dropout=noise_dropout,
                                            score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
            if mask is not None:
                assert x0 is not None
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(x0_partial)
            if callback: callback(i)
            if img_callback: img_callback(img, i)
        return img, intermediates

    @torch.no_grad()
    def p_sample_loop(self, cond, shape, return_intermediates=False,
                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, start_T=None,
                      log_every_t=None):

        if not log_every_t:
            log_every_t = self.log_every_t
        device = self.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        intermediates = [img]
        if timesteps is None:
            timesteps = self.num_timesteps

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
            range(0, timesteps))

        if mask is not None:
            assert x0 is not None
            assert x0.shape[2:4] == mask.shape[2:4]  # spatial size (H, W) has to match

        for i in iterator:
            ts = torch.full((b,), i, device=device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img = self.p_sample(img, cond, ts,
                                clip_denoised=self.clip_denoised,
                                quantize_denoised=quantize_denoised)
            if mask is not None:
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(img)
            if callback: callback(i)
            if img_callback: img_callback(img, i)

        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
               verbose=True, timesteps=None, quantize_denoised=False,
               mask=None, x0=None, shape=None,**kwargs):
        if shape is None:
            shape = (batch_size, self.channels, self.image_size, self.image_size)
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
        return self.p_sample_loop(cond,
                                  shape,
                                  return_intermediates=return_intermediates, x_T=x_T,
                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
                                  mask=mask, x0=x0)

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):

        if ddim:
            ddim_sampler = DDIMSampler(self)
            shape = (self.channels, self.image_size, self.image_size)
            samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
                                                         shape, cond, verbose=False, **kwargs)

        else:
            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
                                                 return_intermediates=True, **kwargs)

        return samples, intermediates


    @torch.no_grad()
    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=True, **kwargs):

        use_ddim = ddim_steps is not None

        log = dict()
        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
                                           return_first_stage_outputs=True,
                                           force_c_encode=True,
                                           return_original_cond=True,
                                           bs=N)
        N = min(x.shape[0], N)
        n_row = min(x.shape[0], n_row)
        log["inputs"] = x
        log["reconstruction"] = xrec
        if self.model.conditioning_key is not None:
            if hasattr(self.cond_stage_model, "decode"):
                xc = self.cond_stage_model.decode(c)
                log["conditioning"] = xc
            elif self.cond_stage_key in ["caption"]:
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
                log["conditioning"] = xc
            elif self.cond_stage_key == 'class_label':
                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
                log['conditioning'] = xc
            elif isimage(xc):
                log["conditioning"] = xc
            if ismap(xc):
                log["original_conditioning"] = self.to_rgb(xc)

        if plot_diffusion_rows:
            # get diffusion row
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    # per-sample timestep tensor; the loop variable t stays an int
                    t_batch = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t_batch = t_batch.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t_batch, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row
            with self.ema_scope("Plotting"):
                samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
                                                         ddim_steps=ddim_steps,eta=ddim_eta)
                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

            if quantize_denoised:
                # also display when quantizing x0 while sampling
                with self.ema_scope("Plotting Quantized Denoised"):
                    samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
                                                             ddim_steps=ddim_steps,eta=ddim_eta,
                                                             quantize_denoised=True)
                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
                    #                                      quantize_denoised=True)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_x0_quantized"] = x_samples

            if inpaint:
                # make a simple center square
                b, h, w = z.shape[0], z.shape[2], z.shape[3]
                mask = torch.ones(N, h, w).to(self.device)
                # zeros will be filled in
                mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
                mask = mask[:, None, ...]
                with self.ema_scope("Plotting Inpaint"):

                    samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
                                                ddim_steps=ddim_steps, x0=z[:N], mask=mask)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_inpainting"] = x_samples
                log["mask"] = mask

                # outpaint: invert the mask so the border is regenerated and the centre is kept
                mask = 1. - mask
                with self.ema_scope("Plotting Outpaint"):
                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
                                                 ddim_steps=ddim_steps, x0=z[:N], mask=mask)
                x_samples = self.decode_first_stage(samples.to(self.device))
                log["samples_outpainting"] = x_samples

        if plot_progressive_rows:
            with self.ema_scope("Plotting Progressives"):
                img, progressives = self.progressive_denoising(c,
                                                               shape=(self.channels, self.image_size, self.image_size),
                                                               batch_size=N)
            prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
            log["progressive_row"] = prog_row

        if return_keys:
            # keep only the requested keys that were actually logged; fall back to the full log if none were
            selected = {key: log[key] for key in return_keys if key in log}
            if selected:
                return selected
        return log

    def configure_optimizers(self):
        lr = self.learning_rate
        params = list(self.model.parameters())
        if self.cond_stage_trainable:
            print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
            params = params + list(self.cond_stage_model.parameters())
        if self.learn_logvar:
            print('Diffusion model optimizing logvar')
            params.append(self.logvar)
        opt = torch.optim.AdamW(params, lr=lr)
        if self.use_scheduler:
            assert 'target' in self.scheduler_config
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }]
            return [opt], scheduler
        return opt

    @torch.no_grad()
    def to_rgb(self, x):
        x = x.float()
        if not hasattr(self, "colorize"):
            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
        x = nn.functional.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x

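
# Illustrative sketch (hypothetical helper, not called by the model): the final
# update in LatentDiffusion.p_sample above, written out on plain tensors. It
# assumes `model_mean` and `model_log_variance` already come from
# p_mean_variance. Fresh Gaussian noise, scaled by exp(0.5 * log_variance) and
# `temperature`, is added to every sample except those whose timestep is 0,
# which return the posterior mean unchanged.
import torch


def _sketch_ancestral_step(model_mean, model_log_variance, t, temperature=1.0, generator=None):
    b = model_mean.shape[0]
    noise = torch.randn(model_mean.shape, generator=generator, dtype=model_mean.dtype,
                        device=model_mean.device) * temperature
    # Shape (b, 1, 1, ...) so the per-sample mask broadcasts over the C, H, W dims.
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (model_mean.dim() - 1)))
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
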

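# Illustrative sketch (hypothetical helper): the centre-square mask that
# LatentDiffusion.log_images builds for its inpainting preview. Ones mark
# latent positions copied from the re-noised input (img_orig * mask), and zeros
# mark the region the sampler fills in. An outpainting preview should use the
# complement, 1 - mask, so that only the border is regenerated.
import torch


def _sketch_center_square_mask(n, h, w):
    mask = torch.ones(n, h, w)
    mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.  # zeros will be filled in
    return mask[:, None, ...]  # (n, 1, h, w), broadcasts over latent channels
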
class DiffusionWrapper(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'mcvd']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        elif self.conditioning_key == 'mcvd':
            out = self.diffusion_model(x, t, cond=c_concat[0])
        else:
            raise NotImplementedError()

        return out

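# Illustrative sketch (hypothetical helper): how DiffusionWrapper combines the
# noisy latent with 'concat' conditioning before calling the denoiser. Each
# tensor in c_concat must match x in batch and spatial size. Channel counts
# add up, so the denoiser's input channels must equal the channels of x plus
# those of every conditioning tensor. For example, in LatentDiffusionVFI the
# previous- and next-frame latents are concatenated into one tensor, so if each
# had as many channels as x, the denoiser would see three times that count.
import torch


def _sketch_concat_input(x, c_concat):
    # Same operation as the 'concat' branch: stack along the channel dim (dim=1).
    return torch.cat([x] + list(c_concat), dim=1)
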

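# Illustrative sketch (hypothetical helper): the scalar objective assembled in
# LatentDiffusion.p_losses above, starting from the per-sample simple loss. In
# p_losses the per-sample VLB term is the same per-sample loss, reweighted by
# lvlb_weights[t]. The weights here are plain arguments, not the model's
# configured values. With zero log-variances, l_simple_weight = 1 and
# original_elbo_weight = 0, the result reduces to the mean of loss_simple.
import torch


def _sketch_weighted_loss(loss_simple, logvar_t, lvlb_weights_t,
                          l_simple_weight=1.0, original_elbo_weight=0.0):
    loss = loss_simple / torch.exp(logvar_t) + logvar_t
    loss = l_simple_weight * loss.mean()
    loss_vlb = (lvlb_weights_t * loss_simple).mean()
    return loss + original_elbo_weight * loss_vlb
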

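# Illustrative sketch (hypothetical helper): the timestep mapping that
# make_cond_schedule (on LatentDiffusionVFI below) stores in self.cond_ids and
# that forward() uses when shorten_cond_schedule is set (num_timesteps_cond > 1).
# The first num_timesteps_cond entries are spread evenly over
# [0, num_timesteps - 1], and every later timestep maps to the last step.
import torch


def _sketch_cond_schedule(num_timesteps, num_timesteps_cond):
    cond_ids = torch.full((num_timesteps,), num_timesteps - 1, dtype=torch.long)
    ids = torch.round(torch.linspace(0, num_timesteps - 1, num_timesteps_cond)).long()
    cond_ids[:num_timesteps_cond] = ids
    return cond_ids
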
#########################################################
#########################################################
#########################################################
#########################################################
#########################################################
#########################################################


class LatentDiffusionVFI(DDPM):
    """main class"""
    def __init__(self,
                 first_stage_config,
                 cond_stage_config,
                 num_timesteps_cond=None,
                 cond_stage_key="image",
                 cond_stage_trainable=False,
                 concat_mode=True,
                 cond_stage_forward=None,
                 conditioning_key=None,
                 scale_factor=1.0,
                 *args, **kwargs):
        self.num_timesteps_cond = default(num_timesteps_cond, 1)
        assert self.num_timesteps_cond <= kwargs['timesteps']
        # for backwards compatibility after implementation of DiffusionWrapper
        if conditioning_key is None:
            conditioning_key = 'concat' if concat_mode else 'crossattn'
        if cond_stage_config == '__is_unconditional__':
            conditioning_key = None
        ckpt_path = kwargs.pop("ckpt_path", None)
        ignore_keys = kwargs.pop("ignore_keys", [])
        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        try:
            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except Exception:
            self.num_downs = 0
        self.scale_factor = scale_factor

        self.instantiate_first_stage(first_stage_config)
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False

        self.restarted_from_ckpt = False
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys)
            self.restarted_from_ckpt = True

    def make_cond_schedule(self):
        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
        self.cond_ids[:self.num_timesteps_cond] = ids


    def register_schedule(self,
                          given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)

        self.shorten_cond_schedule = self.num_timesteps_cond > 1
        if self.shorten_cond_schedule:
            self.make_cond_schedule()

    def instantiate_first_stage(self, config):
        model = instantiate_from_config(config)
        self.first_stage_model = model.eval()
        self.first_stage_model.train = disabled_train
        for param in self.first_stage_model.parameters():
            param.requires_grad = False

    def instantiate_cond_stage(self, config):
        if not self.cond_stage_trainable:
            if config == "__is_first_stage__":
                print("Using first stage also as cond stage.")
                self.cond_stage_model = self.first_stage_model
            elif config == "__is_unconditional__":
                print(f"Training {self.__class__.__name__} as an unconditional model.")
                self.cond_stage_model = None
                # self.be_unconditional = True
            else:
                model = instantiate_from_config(config)
                self.cond_stage_model = model.eval()
                self.cond_stage_model.train = disabled_train
                for param in self.cond_stage_model.parameters():
                    param.requires_grad = False
        else:
            assert config != '__is_first_stage__'
            assert config != '__is_unconditional__'
            model = instantiate_from_config(config)
            self.cond_stage_model = model

    def _get_denoise_row_from_list(self, samples, xc=None, phi_prev_list=None, phi_next_list=None, desc='', force_no_decoder_quantization=False):
        denoise_row = []
        for zd in tqdm(samples, desc=desc):
            denoise_row.append(self.decode_first_stage(zd.to(self.device), xc, phi_prev_list, phi_next_list,
                                                       force_not_quantize=force_no_decoder_quantization))
 
        n_imgs_per_row = len(denoise_row)
        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W
        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
        return denoise_grid

    def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
        return self.scale_factor * z

    def get_learned_conditioning(self, c):
        phi_prev_list, phi_next_list = None, None
        if isinstance(c, dict) and 'prev_frame' in c.keys():
            c_prev, phi_prev_list = self.cond_stage_model.encode(c['prev_frame'], ret_feature=True)
            c_next, phi_next_list = self.cond_stage_model.encode(c['next_frame'], ret_feature=True)
            c = torch.cat([c_prev, c_next], dim=1)
        else:
            c = self.cond_stage_model.encode(c)
        return c, phi_prev_list, phi_next_list


    @torch.no_grad()
    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
                  cond_key=None, return_original_cond=False, return_phi=False, bs=None):
        x = super().get_input(batch, k)
        if bs is not None:
            x = x[:bs]
        x = x.to(self.device)

        encoder_posterior = self.encode_first_stage(x)
        z = self.get_first_stage_encoding(encoder_posterior).detach()
        
        phi_prev_list, phi_next_list = None, None
        assert self.model.conditioning_key is not None
        if cond_key is None:
            cond_key = self.cond_stage_key
        assert cond_key == 'past_future_frames'
        xc = {'prev_frame': super().get_input(batch, 'prev_frame'),
              'next_frame': super().get_input(batch, 'next_frame')}

        if not self.cond_stage_trainable or force_c_encode:
            if isinstance(xc, dict) or isinstance(xc, list):
                c, phi_prev_list, phi_next_list = self.get_learned_conditioning(xc)
            else:
                c, phi_prev_list, phi_next_list = self.get_learned_conditioning(xc.to(self.device))
        else:
            c = xc
        if bs is not None:
            c = c[:bs]
            if isinstance(xc, dict):
                xc['prev_frame'] = xc['prev_frame'][:bs]
                xc['next_frame'] = xc['next_frame'][:bs]
            if phi_prev_list and phi_next_list:
                phi_prev_list = [phi_prev[:bs] for phi_prev in phi_prev_list]
                phi_next_list = [phi_next[:bs] for phi_next in phi_next_list]


        out = [z, c]
        if return_first_stage_outputs:
            xrec = self.decode_first_stage(z, xc, phi_prev_list, phi_next_list)
            out.extend([x, xrec])
        if return_original_cond:
            out.append(xc)
        if return_phi:
            out.append(phi_prev_list)
            out.append(phi_next_list)
        return out

    @torch.no_grad()
    def decode_first_stage(self, z, xc=None, phi_prev_list=None, phi_next_list=None, predict_cids=False, force_not_quantize=False):
        if predict_cids:
            if z.dim() == 4:
                z = torch.argmax(z.exp(), dim=1).long()
            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
            z = rearrange(z, 'b h w c -> b c h w').contiguous()

        z = 1. / self.scale_factor * z

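        # The VFI decoder also consumes the neighbouring frames and their feature lists.
        assert xc is not None, "decode_first_stage requires xc with 'prev_frame' and 'next_frame'"
        return self.first_stage_model.decode(z, xc['prev_frame'], xc['next_frame'], phi_prev_list, phi_next_list,
                                             force_not_quantize=predict_cids or force_not_quantize)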


    @torch.no_grad()
    def encode_first_stage(self, x):
        return self.first_stage_model.encode(x)

    def shared_step(self, batch, **kwargs):
        x, c = self.get_input(batch, self.first_stage_key)
        loss = self(x, c)
        return loss

    def forward(self, x, c, *args, **kwargs):
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        if self.model.conditioning_key is not None:
            assert c is not None
            if self.cond_stage_trainable:
                c, _, _ = self.get_learned_conditioning(c)
            if self.shorten_cond_schedule:  # TODO: drop this option
                tc = self.cond_ids[t].to(self.device)
                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
        return self.p_losses(x, c, t, *args, **kwargs)


    def apply_model(self, x_noisy, t, cond, return_ids=False):

        if isinstance(cond, dict):
            # hybrid case: cond is expected to be a dict
            pass
        else:
            if not isinstance(cond, list):
                cond = [cond]
            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
            cond = {key: cond}
        
        x_recon = self.model(x_noisy, t, **cond)

        if isinstance(x_recon, tuple) and not return_ids:
            return x_recon[0]
        else:
            return x_recon


    def p_losses(self, x_start, cond, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_output = self.apply_model(x_noisy, t, cond)

        loss_dict = {}
        prefix = 'train' if self.training else 'val'

        if self.parameterization == "x0":
            target = x_start
        elif self.parameterization == "eps":
            target = noise
        else:
            raise NotImplementedError()

        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

        logvar_t = self.logvar[t].to(self.device)
        loss = loss_simple / torch.exp(logvar_t) + logvar_t
        # loss = loss_simple / torch.exp(self.logvar) + self.logvar
        if self.learn_logvar:
            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
            loss_dict.update({'logvar': self.logvar.data.mean()})

        loss = self.l_simple_weight * loss.mean()

        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
        loss += (self.original_elbo_weight * loss_vlb)
        loss_dict.update({f'{prefix}/loss': loss})

        return loss, loss_dict

    def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
                        return_x0=False, score_corrector=None, corrector_kwargs=None):
        t_in = t
        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)

        if score_corrector is not None:
            assert self.parameterization == "eps"
            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)

        if return_codebook_ids:
            model_out, logits = model_out

        if self.parameterization == "eps":
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
        else:
            raise NotImplementedError()

        if clip_denoised:
            x_recon.clamp_(-1., 1.)
        if quantize_denoised:
            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        if return_codebook_ids:
            return model_mean, posterior_variance, posterior_log_variance, logits
        elif return_x0:
            return model_mean, posterior_variance, posterior_log_variance, x_recon
        else:
            return model_mean, posterior_variance, posterior_log_variance

    @torch.no_grad()
    def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,
                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
        b, *_, device = *x.shape, x.device
        outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
                                       return_codebook_ids=return_codebook_ids,
                                       quantize_denoised=quantize_denoised,
                                       return_x0=return_x0,
                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
        if return_codebook_ids:
            raise DeprecationWarning("Support dropped.")
            model_mean, _, model_log_variance, logits = outputs
        elif return_x0:
            model_mean, _, model_log_variance, x0 = outputs
        else:
            model_mean, _, model_log_variance = outputs

        noise = noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

        # if return_codebook_ids:
        #     return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
        if return_x0:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
        else:
            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    @torch.no_grad()
    def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
                              log_every_t=None):
        if not log_every_t:
            log_every_t = self.log_every_t
        timesteps = self.num_timesteps
        if batch_size is not None:
            b = batch_size if batch_size is not None else shape[0]
            shape = [batch_size] + list(shape)
        else:
            b = batch_size = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=self.device)
        else:
            img = x_T
        intermediates = []
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
                        total=timesteps) if verbose else reversed(
            range(0, timesteps))
        if type(temperature) == float:
            temperature = [temperature] * timesteps

        for i in iterator:
            ts = torch.full((b,), i, device=self.device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img, x0_partial = self.p_sample(img, cond, ts,
                                            clip_denoised=self.clip_denoised,
                                            quantize_denoised=quantize_denoised, return_x0=True,
                                            temperature=temperature[i], noise_dropout=noise_dropout,
                                            score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
            if mask is not None:
                assert x0 is not None
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(x0_partial)
            if callback: callback(i)
            if img_callback: img_callback(img, i)
        return img, intermediates

    @torch.no_grad()
    def p_sample_loop(self, cond, shape, return_intermediates=False,
                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, start_T=None,
                      log_every_t=None):

        if not log_every_t:
            log_every_t = self.log_every_t
        device = self.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        intermediates = [img]
        if timesteps is None:
            timesteps = self.num_timesteps

        if start_T is not None:
            timesteps = min(timesteps, start_T)
        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
            range(0, timesteps))

        if mask is not None:
            assert x0 is not None
            assert x0.shape[2:4] == mask.shape[2:4]  # spatial size (H, W) has to match

        for i in iterator:
            ts = torch.full((b,), i, device=device, dtype=torch.long)
            if self.shorten_cond_schedule:
                assert self.model.conditioning_key != 'hybrid'
                tc = self.cond_ids[ts].to(cond.device)
                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

            img = self.p_sample(img, cond, ts,
                                clip_denoised=self.clip_denoised,
                                quantize_denoised=quantize_denoised)
            if mask is not None:
                img_orig = self.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            if i % log_every_t == 0 or i == timesteps - 1:
                intermediates.append(img)
            if callback: callback(i)
            if img_callback: img_callback(img, i)

        if return_intermediates:
            return img, intermediates
        return img

    @torch.no_grad()
    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
               verbose=True, timesteps=None, quantize_denoised=False,
               mask=None, x0=None, shape=None,**kwargs):
        if shape is None:
            shape = (batch_size, self.channels, self.image_size, self.image_size)
        if cond is not None:
            if isinstance(cond, dict):
                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
                list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
            else:
                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
        return self.p_sample_loop(cond,
                                  shape,
                                  return_intermediates=return_intermediates, x_T=x_T,
                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
                                  mask=mask, x0=x0)
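The conditioning normalisation at the top of the `apply_model` method above can be isolated as a small pure-Python sketch. `wrap_conditioning` is a hypothetical helper name, not part of the repository, and plain strings stand in for latent tensors so that only the routing logic is visible.

```python
def wrap_conditioning(cond, conditioning_key):
    """Hypothetical helper restating the branch at the top of apply_model."""
    if isinstance(cond, dict):
        # Hybrid case: the caller already supplied a keyed dict; pass it through.
        return cond
    if not isinstance(cond, list):
        # A single conditioning tensor is wrapped in a one-element list.
        cond = [cond]
    # 'concat' conditioning goes to c_concat; every other key (including None)
    # is routed to c_crossattn.
    key = 'c_concat' if conditioning_key == 'concat' else 'c_crossattn'
    return {key: cond}


# Strings stand in for latent tensors; only the routing logic is illustrated.
print(wrap_conditioning("z_cond", "concat"))       # {'c_concat': ['z_cond']}
print(wrap_conditioning(["a", "b"], "crossattn"))  # {'c_crossattn': ['a', 'b']}
```

The dict branch is passed through untouched, so a caller that needs both `c_concat` and `c_crossattn` has to build that dict itself before calling `apply_model`.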
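The loss assembly in the `p_losses` method above can be traced on plain numbers. This is a scalar sketch under stated assumptions: `weighted_diffusion_loss` is a hypothetical helper, each entry of `per_sample_err` stands for one sample's `get_loss(model_output, target, mean=False)` value already averaged over channel and spatial axes, and the lists `logvar_t` and `lvlb_w_t` stand for `self.logvar[t]` and `self.lvlb_weights[t]` gathered per sample.

```python
import math


def weighted_diffusion_loss(per_sample_err, logvar_t, lvlb_w_t,
                            l_simple_weight=1.0, original_elbo_weight=0.0):
    """Hypothetical scalar restatement of the loss assembly in p_losses.

    per_sample_err: one value per sample, standing in for
                    get_loss(model_output, target, mean=False).mean([1, 2, 3]).
    logvar_t:       per-sample values standing in for self.logvar[t].
    lvlb_w_t:       per-sample values standing in for self.lvlb_weights[t].
    Returns (loss, loss_vlb).
    """
    n = len(per_sample_err)
    # loss_simple / exp(logvar_t) + logvar_t, then the batch mean
    gamma = [e / math.exp(lv) + lv for e, lv in zip(per_sample_err, logvar_t)]
    loss = l_simple_weight * sum(gamma) / n
    # The same per-sample error, reweighted by the VLB weights and averaged
    loss_vlb = sum(w * e for w, e in zip(lvlb_w_t, per_sample_err)) / n
    return loss + original_elbo_weight * loss_vlb, loss_vlb


loss, loss_vlb = weighted_diffusion_loss([0.5, 1.5], [0.0, 0.0], [2.0, 4.0],
                                         l_simple_weight=1.0, original_elbo_weight=0.1)
print(round(loss, 6), round(loss_vlb, 6))  # 1.35 3.5
```

When every `logvar_t` is zero, the first term reduces to `l_simple_weight` times the plain batch mean of the per-sample error, and the VLB term only contributes when `original_elbo_weight` is non-zero.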
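The return value of the `p_sample` method above adds scaled Gaussian noise to the posterior mean on every step except the last. A per-element sketch for a single sample, assuming plain Python lists stand in for tensors and the log-variance is a scalar; `ancestral_step` is a hypothetical name, the caller draws `noise` itself instead of using `noise_like`, and `noise_dropout` is omitted.

```python
import math
import random


def ancestral_step(model_mean, model_log_variance, t, noise, temperature=1.0):
    """Hypothetical per-element restatement of the return value of p_sample.

    x_{t-1} = mean + [t != 0] * exp(0.5 * log_var) * temperature * noise
    """
    # Mirrors nonzero_mask: no noise is added on the final step (t == 0).
    nonzero = 0.0 if t == 0 else 1.0
    std = math.exp(0.5 * model_log_variance)
    return [m + nonzero * std * temperature * z for m, z in zip(model_mean, noise)]


random.seed(0)
eps = [random.gauss(0.0, 1.0) for _ in range(3)]
mean = [0.1, -0.2, 0.3]
print(ancestral_step(mean, math.log(0.04), t=0, noise=eps))  # [0.1, -0.2, 0.3]
print(ancestral_step(mean, math.log(0.04), t=5, noise=eps))  # approximately mean + 0.2 * eps
```

Masking the noise at `t == 0` is why the last denoising step returns the model mean unchanged rather than a fresh noisy draw.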
SYMBOL INDEX (577 symbols across 31 files)

FILE: cupy_module/dsepconv.py
  class Stream (line 7) | class Stream:
  function cupy_kernel (line 451) | def cupy_kernel(strFunction, objectVariables):
  function cupy_launch (line 511) | def cupy_launch(strFunction, strKernel):
  class _FunctionDSepconv (line 518) | class _FunctionDSepconv(torch.autograd.Function):
    method forward (line 520) | def forward(self, input, vertical, horizontal, offset_x, offset_y, mask):
    method backward (line 571) | def backward(self, gradOutput):
  function FunctionDSepconv (line 701) | def FunctionDSepconv(tensorInput, tensorVertical, tensorHorizontal, tens...
  class ModuleDSepconv (line 707) | class ModuleDSepconv(torch.nn.Module):
    method __init__ (line 708) | def __init__(self):
    method forward (line 713) | def forward(self, tensorInput, tensorVertical, tensorHorizontal, tenso...

FILE: evaluate.py
  function main (line 26) | def main():

FILE: evaluate_vqm.py
  function main (line 16) | def main():

FILE: interpolate_yuv.py
  function main (line 31) | def main():

FILE: ldm/data/bvi_vimeo.py
  class Vimeo90k_triplet (line 11) | class Vimeo90k_triplet(Dataset):
    method __init__ (line 12) | def __init__(self, db_dir, train=True,  crop_sz=(256,256), augment_s=T...
    method __getitem__ (line 29) | def __getitem__(self, index):
    method __len__ (line 52) | def __len__(self):
  class Vimeo90k_quintuplet (line 56) | class Vimeo90k_quintuplet(Dataset):
    method __init__ (line 57) | def __init__(self, db_dir, train=True,  crop_sz=(256,256), augment_s=T...
    method __getitem__ (line 74) | def __getitem__(self, index):
    method __len__ (line 94) | def __len__(self):
  class BVIDVC_triplet (line 98) | class BVIDVC_triplet(Dataset):
    method __init__ (line 99) | def __init__(self, db_dir, res=None, crop_sz=(256,256), augment_s=True...
    method __getitem__ (line 107) | def __getitem__(self, index):
    method __len__ (line 133) | def __len__(self):
  class BVIDVC_quintuplet (line 137) | class BVIDVC_quintuplet(Dataset):
    method __init__ (line 138) | def __init__(self, db_dir, res=None, crop_sz=(256,256), augment_s=True...
    method __getitem__ (line 146) | def __getitem__(self, index):
    method __len__ (line 166) | def __len__(self):
  class Sampler (line 170) | class Sampler(Dataset):
    method __init__ (line 171) | def __init__(self, datasets, p_datasets=None, iter=False, samples_per_...
    method __getitem__ (line 186) | def __getitem__(self, index):
    method __len__ (line 199) | def __len__(self):
  class BVI_Vimeo_triplet (line 206) | class BVI_Vimeo_triplet(Dataset):
    method __init__ (line 207) | def __init__(self, db_dir, crop_sz=[256,256], p_datasets=None, iter=Fa...
    method __getitem__ (line 225) | def __getitem__(self, index):
    method __len__ (line 238) | def __len__(self):

FILE: ldm/data/testsets.py
  class TripletTestSet (line 17) | class TripletTestSet:
    method __init__ (line 18) | def __init__(self):
    method eval (line 21) | def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_di...
  class Middlebury_others (line 80) | class Middlebury_others(TripletTestSet):
    method __init__ (line 81) | def __init__(self, db_dir):
  class Davis (line 93) | class Davis(TripletTestSet):
    method __init__ (line 94) | def __init__(self, db_dir):
  class Ucf (line 107) | class Ucf(TripletTestSet):
    method __init__ (line 108) | def __init__(self, db_dir):
  class Snufilm (line 121) | class Snufilm(TripletTestSet):
    method __init__ (line 122) | def __init__(self, db_dir, mode):
  class Snufilm_easy (line 139) | class Snufilm_easy(Snufilm):
    method __init__ (line 140) | def __init__(self, db_dir):
  class Snufilm_medium (line 144) | class Snufilm_medium(Snufilm):
    method __init__ (line 145) | def __init__(self, db_dir):
  class Snufilm_hard (line 149) | class Snufilm_hard(Snufilm):
    method __init__ (line 150) | def __init__(self, db_dir):
  class Snufilm_extreme (line 154) | class Snufilm_extreme(Snufilm):
    method __init__ (line 155) | def __init__(self, db_dir):
  class VFITex_triplet (line 159) | class VFITex_triplet:
    method __init__ (line 160) | def __init__(self, db_dir):
    method eval (line 166) | def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_di...
  class Davis90_triplet (line 253) | class Davis90_triplet:
    method __init__ (line 254) | def __init__(self, db_dir):
    method eval (line 260) | def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_di...
  class Ucf101_triplet (line 341) | class Ucf101_triplet:
    method __init__ (line 342) | def __init__(self, db_dir):
    method eval (line 356) | def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_di...

FILE: ldm/data/testsets_vqm.py
  class TripletTestSet (line 16) | class TripletTestSet:
    method __init__ (line 17) | def __init__(self):
    method eval (line 20) | def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name='out...
  class Middlebury_others (line 65) | class Middlebury_others(TripletTestSet):
    method __init__ (line 66) | def __init__(self, db_dir):
  class Davis (line 78) | class Davis(TripletTestSet):
    method __init__ (line 79) | def __init__(self, db_dir):
  class Ucf (line 92) | class Ucf(TripletTestSet):
    method __init__ (line 93) | def __init__(self, db_dir):
  class Snufilm (line 106) | class Snufilm(TripletTestSet):
    method __init__ (line 107) | def __init__(self, db_dir, mode):
  class Snufilm_easy (line 124) | class Snufilm_easy(Snufilm):
    method __init__ (line 125) | def __init__(self, db_dir):
  class Snufilm_medium (line 129) | class Snufilm_medium(Snufilm):
    method __init__ (line 130) | def __init__(self, db_dir):
  class Snufilm_hard (line 134) | class Snufilm_hard(Snufilm):
    method __init__ (line 135) | def __init__(self, db_dir):
  class Snufilm_extreme (line 139) | class Snufilm_extreme(Snufilm):
    method __init__ (line 140) | def __init__(self, db_dir):
  class VFITex_triplet (line 144) | class VFITex_triplet:
    method __init__ (line 145) | def __init__(self, db_dir):
    method eval (line 151) | def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name=None...
  class Davis90_triplet (line 223) | class Davis90_triplet:
    method __init__ (line 224) | def __init__(self, db_dir):
    method eval (line 230) | def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name=None...
  class Ucf101_triplet (line 296) | class Ucf101_triplet:
    method __init__ (line 297) | def __init__(self, db_dir):
    method eval (line 311) | def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name='out...

FILE: ldm/data/vfitransforms.py
  function rand_crop (line 7) | def rand_crop(*args, sz):
  function rand_flip (line 15) | def rand_flip(*args, p):
  function rand_reverse (line 26) | def rand_reverse(*args, p):

FILE: ldm/lr_scheduler.py
  class LambdaWarmUpCosineScheduler (line 4) | class LambdaWarmUpCosineScheduler:
    method __init__ (line 8) | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_...
    method schedule (line 17) | def schedule(self, n, **kwargs):
    method __call__ (line 32) | def __call__(self, n, **kwargs):
  class LambdaWarmUpCosineScheduler2 (line 36) | class LambdaWarmUpCosineScheduler2:
    method __init__ (line 41) | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths...
    method find_in_interval (line 52) | def find_in_interval(self, n):
    method schedule (line 59) | def schedule(self, n, **kwargs):
    method __call__ (line 77) | def __call__(self, n, **kwargs):
  class LambdaLinearScheduler (line 81) | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    method schedule (line 83) | def schedule(self, n, **kwargs):

FILE: ldm/models/autoencoder.py
  class VQFlowNet (line 18) | class VQFlowNet(pl.LightningModule):
    method __init__ (line 19) | def __init__(self,
    method ema_scope (line 74) | def ema_scope(self, context=None):
    method init_from_ckpt (line 88) | def init_from_ckpt(self, path, ignore_keys=list()):
    method on_train_batch_end (line 102) | def on_train_batch_end(self, *args, **kwargs):
    method encode (line 106) | def encode(self, x, ret_feature=False):
    method encode_to_prequant (line 145) | def encode_to_prequant(self, x):
    method decode (line 150) | def decode(self, quant, x_prev, x_next):
    method decode_code (line 167) | def decode_code(self, code_b):
    method forward (line 172) | def forward(self, input, x_prev, x_next, return_pred_indices=False):
    method get_input (line 180) | def get_input(self, batch, k):
    method training_step (line 198) | def training_step(self, batch, batch_idx, optimizer_idx):
    method validation_step (line 221) | def validation_step(self, batch, batch_idx):
    method _validation_step (line 227) | def _validation_step(self, batch, batch_idx, suffix=""):
    method configure_optimizers (line 254) | def configure_optimizers(self):
    method get_last_layer (line 287) | def get_last_layer(self):
    method log_images (line 290) | def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
    method to_rgb (line 314) | def to_rgb(self, x):
  class VQFlowNetInterface (line 323) | class VQFlowNetInterface(VQFlowNet):
    method __init__ (line 324) | def __init__(self, **kwargs):
    method encode (line 327) | def encode(self, x, ret_feature=False):
    method decode (line 366) | def decode(self, h, x_prev, x_next, phi_prev_list, phi_next_list, forc...

FILE: ldm/models/diffusion/ddim.py
  class DDIMSampler (line 11) | class DDIMSampler(object):
    method __init__ (line 12) | def __init__(self, model, schedule="linear", **kwargs):
    method register_buffer (line 18) | def register_buffer(self, name, attr):
    method make_schedule (line 24) | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddi...
    method sample (line 56) | def sample(self,
    method ddim_sampling (line 115) | def ddim_sampling(self, cond, shape,
    method p_sample_ddim (line 168) | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_origin...

FILE: ldm/models/diffusion/ddpm.py
  function disabled_train (line 33) | def disabled_train(self, mode=True):
  class DDPM (line 40) | class DDPM(pl.LightningModule):
    method __init__ (line 42) | def __init__(self,
    method register_schedule (line 113) | def register_schedule(self, given_betas=None, beta_schedule="linear", ...
    method ema_scope (line 168) | def ema_scope(self, context=None):
    method init_from_ckpt (line 182) | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
    method q_mean_variance (line 200) | def q_mean_variance(self, x_start, t):
    method predict_start_from_noise (line 212) | def predict_start_from_noise(self, x_t, t, noise):
    method q_posterior (line 218) | def q_posterior(self, x_start, x_t, t):
    method p_mean_variance (line 227) | def p_mean_variance(self, x, t, clip_denoised: bool):
    method p_sample (line 240) | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
    method p_sample_loop (line 249) | def p_sample_loop(self, shape, return_intermediates=False):
    method sample (line 264) | def sample(self, batch_size=16, return_intermediates=False):
    method q_sample (line 270) | def q_sample(self, x_start, t, noise=None):
    method get_loss (line 275) | def get_loss(self, pred, target, mean=True):
    method p_losses (line 290) | def p_losses(self, x_start, t, noise=None):
    method forward (line 319) | def forward(self, x, *args, **kwargs):
    method get_input (line 325) | def get_input(self, batch, k):
    method shared_step (line 333) | def shared_step(self, batch):
    method training_step (line 338) | def training_step(self, batch, batch_idx):
    method validation_step (line 354) | def validation_step(self, batch, batch_idx):
    method on_train_batch_end (line 362) | def on_train_batch_end(self, *args, **kwargs):
    method _get_rows_from_list (line 366) | def _get_rows_from_list(self, samples):
    method log_images (line 374) | def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=Non...
    method configure_optimizers (line 411) | def configure_optimizers(self):
  class LatentDiffusion (line 420) | class LatentDiffusion(DDPM):
    method __init__ (line 422) | def __init__(self,
    method make_cond_schedule (line 462) | def make_cond_schedule(self, ):
    method register_schedule (line 468) | def register_schedule(self,
    method instantiate_first_stage (line 477) | def instantiate_first_stage(self, config):
    method instantiate_cond_stage (line 484) | def instantiate_cond_stage(self, config):
    method _get_denoise_row_from_list (line 505) | def _get_denoise_row_from_list(self, samples, desc='', force_no_decode...
    method get_first_stage_encoding (line 517) | def get_first_stage_encoding(self, encoder_posterior):
    method get_learned_conditioning (line 524) | def get_learned_conditioning(self, c):
    method get_input (line 537) | def get_input(self, batch, k, return_first_stage_outputs=False, force_...
    method decode_first_stage (line 589) | def decode_first_stage(self, z, predict_cids=False, force_not_quantize...
    method encode_first_stage (line 602) | def encode_first_stage(self, x):
    method shared_step (line 605) | def shared_step(self, batch, **kwargs):
    method forward (line 610) | def forward(self, x, c, *args, **kwargs):
    method apply_model (line 622) | def apply_model(self, x_noisy, t, cond, return_ids=False):
    method p_losses (line 641) | def p_losses(self, x_start, cond, t, noise=None):
    method p_mean_variance (line 676) | def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codeboo...
    method p_sample (line 708) | def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
    method progressive_denoising (line 739) | def progressive_denoising(self, cond, shape, verbose=True, callback=No...
    method p_sample_loop (line 795) | def p_sample_loop(self, cond, shape, return_intermediates=False,
    method sample (line 846) | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=...
    method sample_log (line 864) | def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
    method log_images (line 880) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
    method configure_optimizers (line 989) | def configure_optimizers(self):
    method to_rgb (line 1014) | def to_rgb(self, x):
  class DiffusionWrapper (line 1023) | class DiffusionWrapper(pl.LightningModule):
    method __init__ (line 1024) | def __init__(self, diff_model_config, conditioning_key):
    method forward (line 1030) | def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
  class LatentDiffusionVFI (line 1063) | class LatentDiffusionVFI(DDPM):
    method __init__ (line 1065) | def __init__(self,
    method make_cond_schedule (line 1105) | def make_cond_schedule(self, ):
    method register_schedule (line 1111) | def register_schedule(self,
    method instantiate_first_stage (line 1120) | def instantiate_first_stage(self, config):
    method instantiate_cond_stage (line 1127) | def instantiate_cond_stage(self, config):
    method _get_denoise_row_from_list (line 1148) | def _get_denoise_row_from_list(self, samples, xc=None, phi_prev_list=N...
    method get_first_stage_encoding (line 1161) | def get_first_stage_encoding(self, encoder_posterior):
    method get_learned_conditioning (line 1168) | def get_learned_conditioning(self, c):
    method get_input (line 1180) | def get_input(self, batch, k, return_first_stage_outputs=False, force_...
    method decode_first_stage (line 1228) | def decode_first_stage(self, z, xc=None, phi_prev_list=None, phi_next_...
    method encode_first_stage (line 1241) | def encode_first_stage(self, x):
    method shared_step (line 1244) | def shared_step(self, batch, **kwargs):
    method forward (line 1249) | def forward(self, x, c, *args, **kwargs):
    method apply_model (line 1261) | def apply_model(self, x_noisy, t, cond, return_ids=False):
    method p_losses (line 1280) | def p_losses(self, x_start, cond, t, noise=None):
    method p_mean_variance (line 1315) | def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codeboo...
    method p_sample (line 1347) | def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
    method progressive_denoising (line 1378) | def progressive_denoising(self, cond, shape, verbose=True, callback=No...
    method p_sample_loop (line 1434) | def p_sample_loop(self, cond, shape, return_intermediates=False,
    method sample (line 1485) | def sample(self, cond, batch_size=16, return_intermediates=False, x_T=...
    method sample_ddpm (line 1502) | def sample_ddpm(self, conditioning, batch_size=16, return_intermediate...
    method sample_log (line 1522) | def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
    method log_images (line 1538) | def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200,...
    method configure_optimizers (line 1620) | def configure_optimizers(self):
    method to_rgb (line 1645) | def to_rgb(self, x):

FILE: ldm/modules/attention.py
  function exists (line 11) | def exists(val):
  function uniq (line 15) | def uniq(arr):
  function default (line 19) | def default(val, d):
  function max_neg_value (line 25) | def max_neg_value(t):
  function init_ (line 29) | def init_(tensor):
  class GEGLU (line 37) | class GEGLU(nn.Module):
    method __init__ (line 38) | def __init__(self, dim_in, dim_out):
    method forward (line 42) | def forward(self, x):
  class FeedForward (line 47) | class FeedForward(nn.Module):
    method __init__ (line 48) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
    method forward (line 63) | def forward(self, x):
  function zero_module (line 67) | def zero_module(module):
  function Normalize (line 76) | def Normalize(in_channels):
  class LinearAttention (line 80) | class LinearAttention(nn.Module):
    method __init__ (line 81) | def __init__(self, dim, heads=4, dim_head=32):
    method forward (line 88) | def forward(self, x):
  class SpatialSelfAttention (line 99) | class SpatialSelfAttention(nn.Module):
    method __init__ (line 100) | def __init__(self, in_channels):
    method forward (line 126) | def forward(self, x):
  class CrossAttention (line 152) | class CrossAttention(nn.Module):
    method __init__ (line 157) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
    method forward (line 174) | def forward(self, x, context=None, mask=None):
  class SpatialCrossAttention (line 200) | class SpatialCrossAttention(nn.Module):
    method __init__ (line 207) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
    method forward (line 226) | def forward(self, x, context=None):
  function posemb_sincos_2d (line 259) | def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
  class SpatialCrossAttentionWithPosEmb (line 276) | class SpatialCrossAttentionWithPosEmb(nn.Module):
    method __init__ (line 283) | def __init__(self, in_channels=None, heads=8, dim_head=64, dropout=0.):
    method forward (line 312) | def forward(self, x, context=None):
  class BasicTransformerBlock (line 361) | class BasicTransformerBlock(nn.Module):
    method __init__ (line 369) | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None,...
    method forward (line 380) | def forward(self, x, context=None):
    method _forward (line 383) | def _forward(self, x, context=None):
  class SpatialTransformer (line 390) | class SpatialTransformer(nn.Module):
    method __init__ (line 399) | def __init__(self, in_channels, n_heads, d_head,
    method forward (line 423) | def forward(self, x, context=None):

FILE: ldm/modules/diffusionmodules/model.py
  function get_timestep_embedding (line 13) | def get_timestep_embedding(timesteps, embedding_dim):
  function nonlinearity (line 34) | def nonlinearity(x):
  function Normalize (line 39) | def Normalize(in_channels, num_groups=32):
  class IdentityWrapper (line 43) | class IdentityWrapper(nn.Module):
    method __init__ (line 47) | def __init__(self) -> None:
    method forward (line 51) | def forward(self, x, context=None):
  class Upsample (line 56) | class Upsample(nn.Module):
    method __init__ (line 57) | def __init__(self, in_channels, with_conv):
    method forward (line 67) | def forward(self, x):
  class Downsample (line 74) | class Downsample(nn.Module):
    method __init__ (line 75) | def __init__(self, in_channels, with_conv):
    method forward (line 86) | def forward(self, x):
  class ResnetBlock (line 96) | class ResnetBlock(nn.Module):
    method __init__ (line 97) | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=Fa...
    method forward (line 135) | def forward(self, x, temb):
  class LinAttnBlock (line 158) | class LinAttnBlock(LinearAttention):
    method __init__ (line 160) | def __init__(self, in_channels):
  class AttnBlock (line 164) | class AttnBlock(nn.Module):
    method __init__ (line 165) | def __init__(self, in_channels):
    method forward (line 192) | def forward(self, x):
  function make_attn (line 219) | def make_attn(in_channels, attn_type="vanilla"):
  class FIEncoder (line 234) | class FIEncoder(nn.Module):
    method __init__ (line 235) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
    method forward (line 300) | def forward(self, x, ret_feature=False):
  class FlowEncoder (line 333) | class FlowEncoder(FIEncoder):
    method __init__ (line 334) | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks...
  class FlowDecoderWithResidual (line 354) | class FlowDecoderWithResidual(nn.Module):
    method __init__ (line 355) | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
    method forward (line 520) | def forward(self, z, cond_dict):
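
This file's header reads `# pytorch_diffusion + derived encoder decoder`. In that upstream DDPM/latent-diffusion code, `get_timestep_embedding(timesteps, embedding_dim)` is the standard sinusoidal timestep embedding. The sketch below embeds one integer timestep and assumes the upstream layout:

- sines first, then cosines;
- frequencies spaced by `log(10000) / (half - 1)`;
- one zero of padding when the size is odd.

The name `timestep_embedding_sketch` is hypothetical:

```python
import math


def timestep_embedding_sketch(t, embedding_dim, max_period=10000.0):
    """Sinusoidal embedding of a single integer timestep t.

    First half holds sines, second half cosines; an odd embedding_dim
    is zero-padded by one entry.
    """
    half = embedding_dim // 2
    if half < 2:
        raise ValueError("embedding_dim must be at least 4")
    scale = math.log(max_period) / (half - 1)
    freqs = [math.exp(-scale * i) for i in range(half)]  # 1 down to 1/max_period
    args = [t * f for f in freqs]
    emb = [math.sin(a) for a in args] + [math.cos(a) for a in args]
    if embedding_dim % 2 == 1:
        emb.append(0.0)
    return emb
```

Upstream, the openaimodel-style `timestep_embedding(timesteps, dim, max_period=10000, ...)` listed under `util.py` is a close variant: it puts cosines first and divides by `half` instead of `half - 1`. The two are therefore not interchangeable for a trained checkpoint.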

FILE: ldm/modules/diffusionmodules/openaimodel.py
  function convert_module_to_f16 (line 26) | def convert_module_to_f16(x):
  function convert_module_to_f32 (line 29) | def convert_module_to_f32(x):
  class AttentionPool2d (line 34) | class AttentionPool2d(nn.Module):
    method __init__ (line 39) | def __init__(
    method forward (line 53) | def forward(self, x):
  class TimestepBlock (line 64) | class TimestepBlock(nn.Module):
    method forward (line 70) | def forward(self, x, emb):
  class TimestepEmbedSequential (line 76) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    method forward (line 82) | def forward(self, x, emb, context=None):
  class Upsample (line 93) | class Upsample(nn.Module):
    method __init__ (line 102) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
    method forward (line 111) | def forward(self, x):
  class TransposedUpsample (line 123) | class TransposedUpsample(nn.Module):
    method __init__ (line 125) | def __init__(self, channels, out_channels=None, ks=5):
    method forward (line 132) | def forward(self,x):
  class Downsample (line 136) | class Downsample(nn.Module):
    method __init__ (line 145) | def __init__(self, channels, use_conv, dims=2, out_channels=None,paddi...
    method forward (line 160) | def forward(self, x):
  class ResBlock (line 165) | class ResBlock(TimestepBlock):
    method __init__ (line 181) | def __init__(
    method forward (line 245) | def forward(self, x, emb):
    method _forward (line 257) | def _forward(self, x, emb):
  class AttentionBlock (line 280) | class AttentionBlock(nn.Module):
    method __init__ (line 287) | def __init__(
    method forward (line 316) | def forward(self, x):
    method _forward (line 320) | def _forward(self, x):
  function count_flops_attn (line 329) | def count_flops_attn(model, _x, y):
  class QKVAttentionLegacy (line 349) | class QKVAttentionLegacy(nn.Module):
    method __init__ (line 354) | def __init__(self, n_heads):
    method forward (line 358) | def forward(self, qkv):
    method count_flops (line 377) | def count_flops(model, _x, y):
  class QKVAttention (line 381) | class QKVAttention(nn.Module):
    method __init__ (line 386) | def __init__(self, n_heads):
    method forward (line 390) | def forward(self, qkv):
    method count_flops (line 411) | def count_flops(model, _x, y):
  class UNetModel (line 415) | class UNetModel(nn.Module):
    method __init__ (line 445) | def __init__(
    method convert_to_fp16 (line 711) | def convert_to_fp16(self):
    method convert_to_fp32 (line 719) | def convert_to_fp32(self):
    method forward (line 727) | def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
  class EncoderUNetModel (line 762) | class EncoderUNetModel(nn.Module):
    method __init__ (line 768) | def __init__(
    method convert_to_fp16 (line 941) | def convert_to_fp16(self):
    method convert_to_fp32 (line 948) | def convert_to_fp32(self):
    method forward (line 955) | def forward(self, x, timesteps):
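
`QKVAttention` and `QKVAttentionLegacy` both compute multi-head scaled dot-product attention, `softmax(Q K^T / sqrt(d)) V`. In the upstream improved-diffusion code they differ only in how the packed `qkv` tensor is split into heads. Both scale `q` and `k` each by `d ** -0.25` for fp16 stability, which is mathematically the same as dividing the logits by `sqrt(d)`. Below is a single-head sketch on plain lists; `attention_sketch` is hypothetical:

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_sketch(q, k, v):
    """Single-head scaled dot-product attention.

    q: list of query vectors; k, v: equally long lists of key/value vectors.
    Returns one output vector per query.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for qi in q:
        logits = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        weights = softmax(logits)  # one weight per key, summing to 1
        out.append([sum(wt * vj[c] for wt, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out
```

With a single key, the softmax weight is exactly 1, so the output equals that key's value vector.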

FILE: ldm/modules/diffusionmodules/util.py
  function make_beta_schedule (line 21) | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_e...
  function make_ddim_timesteps (line 46) | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_...
  function make_ddim_sampling_parameters (line 63) | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbos...
  function betas_for_alpha_bar (line 77) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
  function extract_into_tensor (line 96) | def extract_into_tensor(a, t, x_shape):
  function checkpoint (line 102) | def checkpoint(func, inputs, params, flag):
  class CheckpointFunction (line 119) | class CheckpointFunction(torch.autograd.Function):
    method forward (line 121) | def forward(ctx, run_function, length, *args):
    method backward (line 131) | def backward(ctx, *output_grads):
  function timestep_embedding (line 151) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
  function zero_module (line 174) | def zero_module(module):
  function scale_module (line 183) | def scale_module(module, scale):
  function mean_flat (line 192) | def mean_flat(tensor):
  function normalization (line 199) | def normalization(channels):
  class SiLU (line 209) | class SiLU(nn.Module):
    method forward (line 210) | def forward(self, x):
  class GroupNorm32 (line 214) | class GroupNorm32(nn.GroupNorm):
    method forward (line 215) | def forward(self, x):
  function conv_nd (line 218) | def conv_nd(dims, *args, **kwargs):
  function linear (line 231) | def linear(*args, **kwargs):
  function avg_pool_nd (line 238) | def avg_pool_nd(dims, *args, **kwargs):
  class HybridConditioner (line 251) | class HybridConditioner(nn.Module):
    method __init__ (line 253) | def __init__(self, c_concat_config, c_crossattn_config):
    method forward (line 258) | def forward(self, c_concat, c_crossattn):
  function noise_like (line 264) | def noise_like(shape, device, repeat=False):
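
`make_beta_schedule(schedule, n_timestep, linear_start=1e-4, ...)` builds the DDPM noise schedule. In upstream latent-diffusion, the `"linear"` option spaces `sqrt(beta)` linearly between `sqrt(linear_start)` and `sqrt(linear_end)` and then squares the result. This sketch assumes the same rule here. The cumulative product of `1 - beta_t` (alpha-bar) is what `q(x_t | x_0)` and DDIM sampling consume. The function names are hypothetical. `linear_end` has no default because the listed signature is truncated:

```python
def linear_beta_schedule_sketch(n_timestep, linear_start, linear_end):
    """'linear' DDPM schedule: sqrt(beta) spaced linearly, then squared."""
    if n_timestep < 2:
        raise ValueError("need at least two timesteps")
    a, b = linear_start ** 0.5, linear_end ** 0.5
    step = (b - a) / (n_timestep - 1)
    return [(a + i * step) ** 2 for i in range(n_timestep)]


def alphas_cumprod(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    out, acc = [], 1.0
    for beta in betas:
        acc *= 1.0 - beta
        out.append(acc)
    return out
```

For example, `linear_beta_schedule_sketch(1000, 1e-4, 2e-2)` uses illustrative values, not ones read from the config. It starts at `1e-4`, ends at `2e-2`, and yields a strictly decreasing alpha-bar.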

FILE: ldm/modules/ema.py
  class LitEma (line 5) | class LitEma(nn.Module):
    method __init__ (line 6) | def __init__(self, model, decay=0.9999, use_num_upates=True):
    method forward (line 25) | def forward(self,model):
    method copy_to (line 46) | def copy_to(self, model):
    method store (line 55) | def store(self, parameters):
    method restore (line 64) | def restore(self, parameters):
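
`LitEma(model, decay=0.9999, use_num_upates=True)` keeps shadow copies of the model weights. Upstream latent-diffusion warms the decay up as `min(decay, (1 + n) / (10 + n))` after `n` updates, then moves each shadow value toward the live one. The sketch assumes the same rule and uses a plain dict of floats in place of registered buffers. `EmaSketch` is hypothetical:

```python
class EmaSketch:
    """Exponential moving average of named float parameters with decay warm-up."""

    def __init__(self, params, decay=0.9999, use_num_updates=True):
        self.decay = decay
        self.num_updates = 0 if use_num_updates else -1  # -1 disables warm-up
        self.shadow = dict(params)  # start from a copy of the current weights

    def update(self, params):
        decay = self.decay
        if self.num_updates >= 0:
            self.num_updates += 1
            # Small effective decay early on, approaching self.decay later.
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
        for name, value in params.items():
            # shadow <- shadow - (1 - decay) * (shadow - value)
            self.shadow[name] -= (1.0 - decay) * (self.shadow[name] - value)
        return decay
```

Upstream, `copy_to`, `store` and `restore` (listed above) swap the EMA weights into the model for evaluation and then swap the live weights back.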

FILE: ldm/modules/losses/vqperceptual.py
  function hinge_d_loss_with_exemplar_weights (line 11) | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
  function adopt_weight (line 20) | def adopt_weight(weight, global_step, threshold=0, value=0.):
  function measure_perplexity (line 26) | def measure_perplexity(predicted_indices, n_embed):
  function l1 (line 35) | def l1(x, y):
  function l2 (line 39) | def l2(x, y):
  class VQLPIPSWithDiscriminator (line 43) | class VQLPIPSWithDiscriminator(nn.Module):
    method __init__ (line 44) | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
    method calculate_adaptive_weight (line 85) | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
    method forward (line 98) | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
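
Three small helpers in this file are easy to state on their own. The sketches assume they follow upstream latent-diffusion:

- `adopt_weight` returns `value` (0 by default) until `global_step` reaches `threshold`. This is how the GAN term is held off until `disc_start`.
- The exemplar-weighted hinge loss averages the hinge terms with per-sample weights.
- `measure_perplexity` reports `exp` of the entropy of codebook usage, plus the number of codes used.

The names carry a hypothetical `_sketch` suffix. Each sample has a single logit instead of a logit map:

```python
import math


def adopt_weight_sketch(weight, global_step, threshold=0, value=0.0):
    """Return `value` before `threshold` steps, `weight` afterwards."""
    return value if global_step < threshold else weight


def hinge_d_loss_weighted_sketch(logits_real, logits_fake, weights):
    """Hinge discriminator loss with per-sample weights."""
    total = sum(weights)
    loss_real = sum(w * max(0.0, 1.0 - r) for w, r in zip(weights, logits_real)) / total
    loss_fake = sum(w * max(0.0, 1.0 + f) for w, f in zip(weights, logits_fake)) / total
    return 0.5 * (loss_real + loss_fake)


def measure_perplexity_sketch(indices, n_embed, eps=1e-10):
    """Codebook perplexity exp(H(p)) and number of codes used at least once."""
    counts = [0] * n_embed
    for i in indices:
        counts[i] += 1
    probs = [c / len(indices) for c in counts]
    entropy = -sum(p * math.log(p + eps) for p in probs)
    return math.exp(entropy), sum(1 for p in probs if p > 0)
```

Perplexity equals `n_embed` when all codes are used equally often and drops to 1 when a single code is used.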

FILE: ldm/modules/maxvit.py
  function exists (line 12) | def exists(val):
  function default (line 16) | def default(val, d):
  class PreNormResidual (line 22) | class PreNormResidual(nn.Module):
    method __init__ (line 23) | def __init__(self, dim, fn):
    method forward (line 28) | def forward(self, x, c=None):
  class SqueezeExcitation (line 34) | class SqueezeExcitation(nn.Module):
    method __init__ (line 35) | def __init__(self, dim, shrinkage_rate = 0.25):
    method forward (line 48) | def forward(self, x):
  class FeedForward (line 52) | class FeedForward(nn.Module):
    method __init__ (line 53) | def __init__(self, dim, mult = 4, dropout = 0.):
    method forward (line 63) | def forward(self, x):
  class Attention (line 67) | class Attention(nn.Module):
    method __init__ (line 68) | def __init__(
    method forward (line 108) | def forward(self, x, c=None):
  class Dropsample (line 158) | class Dropsample(nn.Module):
    method __init__ (line 159) | def __init__(self, prob = 0):
    method forward (line 163) | def forward(self, x):
  class MBConvResidual (line 173) | class MBConvResidual(nn.Module):
    method __init__ (line 174) | def __init__(self, fn, dropout = 0.):
    method forward (line 179) | def forward(self, x):
  function MBConv (line 185) | def MBConv(
  class MaxAttentionBlock (line 215) | class MaxAttentionBlock(nn.Module):
    method __init__ (line 216) | def __init__(self, in_channels, heads=8, dim_head=64, dropout=0., wind...
    method forward (line 232) | def forward(self, x):
  class SpatialCrossAttentionWithMax (line 249) | class SpatialCrossAttentionWithMax(nn.Module):
    method __init__ (line 250) | def __init__(self, in_channels, heads=8, dim_head=64, ctx_dim=None, dr...
    method forward (line 274) | def forward(self, x, context=None):
  class SpatialTransformerWithMax (line 298) | class SpatialTransformerWithMax(nn.Module):
    method __init__ (line 307) | def __init__(self, in_channels, n_heads, d_head, dropout=0., context_d...
    method forward (line 325) | def forward(self, x, context=None):
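
MaxViT (Tu et al., 2022) factorises attention into two steps:

- block attention inside local, non-overlapping windows;
- grid attention over a sparse, dilated grid that spans the whole feature map.

This sketch assumes `MaxAttentionBlock` and `SpatialCrossAttentionWithMax` follow that scheme. It only shows which pixels each partition groups together, and the function names are hypothetical:

```python
def block_windows(h, w, win):
    """Block (local) partition: non-overlapping win x win tiles."""
    if h % win or w % win:
        raise ValueError("h and w must be divisible by win")
    groups = {}
    for r in range(h):
        for c in range(w):
            groups.setdefault((r // win, c // win), []).append((r, c))
    return groups


def grid_windows(h, w, win):
    """Grid (dilated, global) partition: each group is a win x win grid
    whose members are h // win rows and w // win columns apart."""
    if h % win or w % win:
        raise ValueError("h and w must be divisible by win")
    sh, sw = h // win, w // win
    groups = {}
    for r in range(h):
        for c in range(w):
            groups.setdefault((r % sh, c % sw), []).append((r, c))
    return groups
```

On a 4x4 map with window 2, block window (0, 0) is the top-left 2x2 tile. Grid group (0, 0) gathers pixels two apart: (0, 0), (0, 2), (2, 0), (2, 2). Each grid group therefore sees the whole map at reduced density.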

FILE: ldm/util.py
  function log_txt_as_img (line 17) | def log_txt_as_img(wh, xc, size=10):
  function ismap (line 41) | def ismap(x):
  function isimage (line 47) | def isimage(x):
  function exists (line 53) | def exists(x):
  function default (line 57) | def default(val, d):
  function mean_flat (line 63) | def mean_flat(tensor):
  function count_params (line 71) | def count_params(model, verbose=False):
  function instantiate_from_config (line 78) | def instantiate_from_config(config):
  function get_obj_from_str (line 88) | def get_obj_from_str(string, reload=False):
  function _do_parallel_data_prefetch (line 96) | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
  function parallel_data_prefetch (line 108) | def parallel_data_prefetch(
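
`instantiate_from_config` turns the YAML configs into objects. Take a mapping with a dotted `target` (for example `ldm.models.autoencoder.VQFlowNet` in `configs/autoencoder/vqflow-f32.yaml`) and optional `params`: the target is resolved with `importlib` and called with those params. Below is a stdlib-only sketch of that pattern. The function names are hypothetical, and the repository's version may also special-case placeholder strings:

```python
import importlib


def obj_from_dotted_path(string, reload=False):
    """Resolve 'package.module.Name' to the named object."""
    module_name, attr = string.rsplit(".", 1)
    module = importlib.import_module(module_name)
    if reload:
        module = importlib.reload(module)
    return getattr(module, attr)


def instantiate_sketch(config):
    """Call config['target'] with keyword arguments from config['params']."""
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return obj_from_dotted_path(config["target"])(**config.get("params", {}))
```

A stdlib target shows the mechanics: `{"target": "fractions.Fraction", "params": {"numerator": 1, "denominator": 2}}` builds `Fraction(1, 2)`.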

FILE: main.py
  function get_parser (line 23) | def get_parser(**parser_kwargs):
  function nondefault_trainer_args (line 125) | def nondefault_trainer_args(opt):
  class WrappedDataset (line 132) | class WrappedDataset(Dataset):
    method __init__ (line 135) | def __init__(self, dataset):
    method __len__ (line 138) | def __len__(self):
    method __getitem__ (line 141) | def __getitem__(self, idx):
  function worker_init_fn (line 145) | def worker_init_fn(_):
  class DataModuleFromConfig (line 154) | class DataModuleFromConfig(pl.LightningDataModule):
    method __init__ (line 155) | def __init__(self, batch_size, train=None, validation=None, test=None,...
    method prepare_data (line 177) | def prepare_data(self):
    method setup (line 181) | def setup(self, stage=None):
    method _train_dataloader (line 189) | def _train_dataloader(self):
    method _val_dataloader (line 198) | def _val_dataloader(self, shuffle=False):
    method _test_dataloader (line 209) | def _test_dataloader(self, shuffle=False):
    method _predict_dataloader (line 218) | def _predict_dataloader(self, shuffle=False):
  class SetupCallback (line 227) | class SetupCallback(Callback):
    method __init__ (line 228) | def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, light...
    method on_keyboard_interrupt (line 238) | def on_keyboard_interrupt(self, trainer, pl_module):
    method on_fit_start (line 244) | def on_fit_start(self, trainer, pl_module):
  class ImageLogger (line 276) | class ImageLogger(Callback):
    method __init__ (line 277) | def __init__(self, batch_frequency, val_batch_frequency, max_images, c...
    method _testtube (line 299) | def _testtube(self, pl_module, images, batch_idx, split):
    method log_local (line 310) | def log_local(self, save_dir, split, images,
    method log_img (line 329) | def log_img(self, pl_module, batch, batch_idx, split="train"):
    method check_frequency (line 382) | def check_frequency(self, check_idx):
    method on_train_batch_end (line 393) | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch...
    method on_validation_batch_end (line 397) | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, ...
    method on_validation_epoch_end (line 404) | def on_validation_epoch_end(self, trainer, pl_module):
  class CUDACallback (line 410) | class CUDACallback(Callback):
    method on_train_epoch_start (line 412) | def on_train_epoch_start(self, trainer, pl_module):
    method on_train_epoch_end (line 418) | def on_train_epoch_end(self, trainer, pl_module, outputs=None):
  function melk (line 681) | def melk(*args, **kwargs):
  function divein (line 689) | def divein(*args, **kwargs):

FILE: metrics/flolpips/correlation/correlation.py
  function cupy_kernel (line 235) | def cupy_kernel(strFunction, objVariables):
  function cupy_launch (line 274) | def cupy_launch(strFunction, strKernel):
  class _FunctionCorrelation (line 278) | class _FunctionCorrelation(torch.autograd.Function):
    method forward (line 280) | def forward(self, first, second):
    method backward (line 333) | def backward(self, gradOutput):
  function FunctionCorrelation (line 385) | def FunctionCorrelation(tenFirst, tenSecond):
  class ModuleCorrelation (line 389) | class ModuleCorrelation(torch.nn.Module):
    method __init__ (line 390) | def __init__(self):
    method forward (line 394) | def forward(self, tenFirst, tenSecond):
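
The CuPy kernel in this file implements a PWC-Net style correlation (cost-volume) layer. This sketch assumes it matches the widely used implementation from sniklaus' `pytorch-pwc`. In that implementation, each output channel corresponds to one displacement `(dy, dx)` within a search radius; PWC-Net uses 4 pixels, giving 81 channels. Each channel holds the channel-averaged product of the first feature map with the displaced second one. Below is a slow, pure-Python reference sketch on nested lists `[C][H][W]`; `correlation_sketch` is hypothetical:

```python
def correlation_sketch(f1, f2, max_disp):
    """Local correlation between two feature maps given as [C][H][W] lists.

    Returns (2 * max_disp + 1) ** 2 planes of size H x W, one per
    displacement (dy, dx) in row-major order. Each entry is the
    channel-averaged dot product f1(y, x) . f2(y + dy, x + dx), or 0.0
    where the displaced position falls outside f2.
    """
    C, H, W = len(f1), len(f1[0]), len(f1[0][0])
    out = []
    for dy in range(-max_disp, max_disp + 1):
        for dx in range(-max_disp, max_disp + 1):
            plane = [[0.0] * W for _ in range(H)]
            for y in range(H):
                for x in range(W):
                    y2, x2 = y + dy, x + dx
                    if 0 <= y2 < H and 0 <= x2 < W:
                        plane[y][x] = sum(f1[c][y][x] * f2[c][y2][x2] for c in range(C)) / C
            out.append(plane)
    return out
```

The centre plane (zero displacement) of a map correlated with itself holds the mean squared feature at each pixel.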

FILE: metrics/flolpips/flolpips.py
  function spatial_average (line 17) | def spatial_average(in_tens, keepdim=True):
  function mw_spatial_average (line 20) | def mw_spatial_average(in_tens, flow, keepdim=True):
  function mtw_spatial_average (line 28) | def mtw_spatial_average(in_tens, flow, texture, keepdim=True):
  function m2w_spatial_average (line 41) | def m2w_spatial_average(in_tens, flow, keepdim=True):
  function upsample (line 48) | def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same fo...
  class LPIPS (line 53) | class LPIPS(nn.Module):
    method __init__ (line 54) | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=T...
    method forward (line 111) | def forward(self, in0, in1, retPerLayer=False, normalize=False):
  class ScalingLayer (line 157) | class ScalingLayer(nn.Module):
    method __init__ (line 158) | def __init__(self):
    method forward (line 163) | def forward(self, inp):
  class NetLinLayer (line 167) | class NetLinLayer(nn.Module):
    method __init__ (line 169) | def __init__(self, chn_in, chn_out=1, use_dropout=False):
    method forward (line 176) | def forward(self, x):
  class Dist2LogitLayer (line 179) | class Dist2LogitLayer(nn.Module):
    method __init__ (line 181) | def __init__(self, chn_mid=32, use_sigmoid=True):
    method forward (line 193) | def forward(self,d0,d1,eps=0.1):
  class BCERankingLoss (line 196) | class BCERankingLoss(nn.Module):
    method __init__ (line 197) | def __init__(self, chn_mid=32):
    method forward (line 203) | def forward(self, d0, d1, judge):
  class FakeNet (line 209) | class FakeNet(nn.Module):
    method __init__ (line 210) | def __init__(self, use_gpu=True, colorspace='Lab'):
  class L2 (line 215) | class L2(FakeNet):
    method forward (line 216) | def forward(self, in0, in1, retPerLayer=None):
  class DSSIM (line 231) | class DSSIM(FakeNet):
    method forward (line 233) | def forward(self, in0, in1, retPerLayer=None):
  function print_network (line 246) | def print_network(net):
  class FloLPIPS (line 254) | class FloLPIPS(LPIPS):
    method __init__ (line 255) | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=T...
    method forward (line 258) | def forward(self, in0, in1, flow, retPerLayer=False, normalize=False):
  function calc_flolpips (line 278) | def calc_flolpips(dis_path, ref_path):
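
Per its README, FloLPIPS is a bespoke video quality metric for frame interpolation. It replaces LPIPS's plain `spatial_average` of the per-pixel distance map with flow-aware weighting (`mw_spatial_average(in_tens, flow, ...)` and its variants). The sketch below assumes weights proportional to optical-flow magnitude, normalised to sum to one. This listing does not show which flow field is used or how it is resized to each layer's resolution, so treat both as assumptions. Inputs are plain 2D lists, and the names are hypothetical:

```python
import math


def spatial_average_sketch(dist):
    """Plain mean of a per-pixel distance map (list of rows)."""
    values = [v for row in dist for v in row]
    return sum(values) / len(values)


def motion_weighted_average_sketch(dist, flow_u, flow_v):
    """Average `dist` with weights proportional to flow magnitude."""
    mags = [[math.hypot(u, v) for u, v in zip(ru, rv)]
            for ru, rv in zip(flow_u, flow_v)]
    total = sum(m for row in mags for m in row)
    if total == 0:  # no motion anywhere: fall back to the plain mean
        return spatial_average_sketch(dist)
    return sum(d * m for rd, rm in zip(dist, mags) for d, m in zip(rd, rm)) / total
```

Uniform motion reduces to the ordinary mean. Motion at a single pixel makes the score equal to that pixel's distance. The zero-motion fallback is a choice made for this sketch.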

FILE: metrics/flolpips/pretrained_networks.py
  class squeezenet (line 5) | class squeezenet(torch.nn.Module):
    method __init__ (line 6) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 35) | def forward(self, X):
  class alexnet (line 56) | class alexnet(torch.nn.Module):
    method __init__ (line 57) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 80) | def forward(self, X):
  class vgg16 (line 96) | class vgg16(torch.nn.Module):
    method __init__ (line 97) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 120) | def forward(self, X):
  class resnet (line 138) | class resnet(torch.nn.Module):
    method __init__ (line 139) | def __init__(self, requires_grad=False, pretrained=True, num=18):
    method forward (line 162) | def forward(self, X):

FILE: metrics/flolpips/pwcnet.py
  function backwarp (line 45) | def backwarp(tenInput, tenFlow):
  class Network (line 71) | class Network(torch.nn.Module):
    method __init__ (line 72) | def __init__(self):
    method forward (line 263) | def forward(self, tenFirst, tenSecond, *args):
    method extract_pyramid_single (line 295) | def extract_pyramid_single(self, tenFirst):
  function estimate (line 310) | def estimate(tenFirst, tenSecond):
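
`backwarp(tenInput, tenFlow)` warps a tensor backwards by an optical-flow field. PWC-Net uses this to align second-frame features to the first at each pyramid level. PyTorch ports of PWC-Net typically do this with `torch.nn.functional.grid_sample` on normalised coordinates. This pure-Python sketch shows only the underlying bilinear sampling for a single-channel image. `backwarp_sketch` is hypothetical, and out-of-range samples are treated as zero:

```python
import math


def backwarp_sketch(img, flow_u, flow_v):
    """Backward-warp a single-channel image (list of rows) by a dense flow.

    out[y][x] bilinearly samples img at (x + u, y + v); taps outside the
    image contribute zero.
    """
    h, w = len(img), len(img[0])

    def pixel(yy, xx):
        return img[yy][xx] if 0 <= yy < h and 0 <= xx < w else 0.0

    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            sx, sy = x + flow_u[y][x], y + flow_v[y][x]
            x0, y0 = math.floor(sx), math.floor(sy)
            ax, ay = sx - x0, sy - y0  # fractional offsets in [0, 1)
            out[y][x] = ((1 - ax) * (1 - ay) * pixel(y0, x0)
                         + ax * (1 - ay) * pixel(y0, x0 + 1)
                         + (1 - ax) * ay * pixel(y0 + 1, x0)
                         + ax * ay * pixel(y0 + 1, x0 + 1))
    return out
```

Zero flow returns the image unchanged. A flow of +1 in `u` pulls each pixel from its right-hand neighbour.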

FILE: metrics/flolpips/utils.py
  function normalize_tensor (line 6) | def normalize_tensor(in_feat,eps=1e-10):
  function l2 (line 10) | def l2(p0, p1, range=255.):
  function dssim (line 13) | def dssim(p0, p1, range=255.):
  function tensor2im (line 17) | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
  function tensor2np (line 22) | def tensor2np(tensor_obj):
  function np2tensor (line 26) | def np2tensor(np_obj):
  function tensor2tensorlab (line 30) | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
  function read_frame_yuv2rgb (line 44) | def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt...
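
`read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, ...)` reads one frame from a raw YUV file. The sketch below shows the byte arithmetic behind such a reader. It assumes planar Y, U, V storage, with samples above 8 bits stored in 2 bytes each, as in formats such as `yuv420p10le`. The `chroma` argument and both function names are hypothetical. They do not describe the truncated `pix_fmt` parameter:

```python
def yuv_frame_layout(width, height, bit_depth=8, chroma="420"):
    """Byte sizes (Y plane, each chroma plane, whole frame) of a planar YUV frame."""
    bytes_per_sample = 1 if bit_depth <= 8 else 2
    if chroma == "420":
        chroma_samples = (width // 2) * (height // 2)
    elif chroma == "422":
        chroma_samples = (width // 2) * height
    elif chroma == "444":
        chroma_samples = width * height
    else:
        raise ValueError(f"unsupported chroma format: {chroma}")
    y_bytes = width * height * bytes_per_sample
    c_bytes = chroma_samples * bytes_per_sample
    return y_bytes, c_bytes, y_bytes + 2 * c_bytes


def frame_offset(i_frame, width, height, bit_depth=8, chroma="420"):
    """Byte offset at which 0-based frame `i_frame` starts in the file."""
    return i_frame * yuv_frame_layout(width, height, bit_depth, chroma)[2]
```

Under these assumptions, a 1920x1080 8-bit 4:2:0 frame occupies 3,110,400 bytes, and frame `i` starts at `i` times that.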

FILE: metrics/lpips/__init__.py
  function normalize_tensor (line 42) | def normalize_tensor(in_feat,eps=1e-10):
  function l2 (line 46) | def l2(p0, p1, range=255.):
  function psnr (line 49) | def psnr(p0, p1, peak=255.):
  function dssim (line 52) | def dssim(p0, p1, range=255.):
  function rgb2lab (line 56) | def rgb2lab(in_img,mean_cent=False):
  function tensor2np (line 63) | def tensor2np(tensor_obj):
  function np2tensor (line 67) | def np2tensor(np_obj):
  function tensor2tensorlab (line 71) | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
  function tensorlab2tensor (line 85) | def tensorlab2tensor(lab_tensor,return_inbnd=False):
  function load_image (line 103) | def load_image(path):
  function rgb2lab (line 116) | def rgb2lab(input):
  function tensor2im (line 120) | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
  function im2tensor (line 125) | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
  function tensor2vec (line 129) | def tensor2vec(vector_tensor):
  function tensor2im (line 133) | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
  function im2tensor (line 139) | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
  function voc_ap (line 146) | def voc_ap(rec, prec, use_07_metric=False):
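
Two of these utilities reduce to one-line formulas, assuming this file follows the upstream LPIPS package:

- `normalize_tensor` divides each spatial position's feature vector by its L2 norm across channels (plus `eps`).
- `psnr` is `10 * log10(peak**2 / MSE)`.

They are sketched below on plain lists with hypothetical names:

```python
import math


def normalize_channels_sketch(feat, eps=1e-10):
    """Unit-normalise one position's feature vector across channels."""
    norm = math.sqrt(sum(v * v for v in feat))
    return [v / (norm + eps) for v in feat]


def psnr_sketch(p0, p1, peak=255.0):
    """PSNR in dB between two equal-length sample lists."""
    mse = sum((a - b) ** 2 for a, b in zip(p0, p1)) / len(p0)
    return math.inf if mse == 0 else 10 * math.log10(peak ** 2 / mse)
```

Returning `inf` for identical inputs avoids a `ZeroDivisionError` in pure Python.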

FILE: metrics/lpips/lpips.py
  function spatial_average (line 14) | def spatial_average(in_tens, keepdim=True):
  function upsample (line 22) | def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same fo...
  class LPIPS (line 27) | class LPIPS(nn.Module):
    method __init__ (line 28) | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=T...
    method forward (line 85) | def forward(self, in0, in1, retPerLayer=False, normalize=False):
  class ScalingLayer (line 132) | class ScalingLayer(nn.Module):
    method __init__ (line 133) | def __init__(self):
    method forward (line 138) | def forward(self, inp):
  class NetLinLayer (line 142) | class NetLinLayer(nn.Module):
    method __init__ (line 144) | def __init__(self, chn_in, chn_out=1, use_dropout=False):
    method forward (line 151) | def forward(self, x):
  class Dist2LogitLayer (line 154) | class Dist2LogitLayer(nn.Module):
    method __init__ (line 156) | def __init__(self, chn_mid=32, use_sigmoid=True):
    method forward (line 168) | def forward(self,d0,d1,eps=0.1):
  class BCERankingLoss (line 171) | class BCERankingLoss(nn.Module):
    method __init__ (line 172) | def __init__(self, chn_mid=32):
    method forward (line 178) | def forward(self, d0, d1, judge):
  class FakeNet (line 184) | class FakeNet(nn.Module):
    method __init__ (line 185) | def __init__(self, use_gpu=True, colorspace='Lab'):
  class L2 (line 190) | class L2(FakeNet):
    method forward (line 191) | def forward(self, in0, in1, retPerLayer=None):
  class DSSIM (line 206) | class DSSIM(FakeNet):
    method forward (line 208) | def forward(self, in0, in1, retPerLayer=None):
  function print_network (line 221) | def print_network(net):
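
Assume `lpips.py` follows the upstream LPIPS package. Then `LPIPS.forward` with learned linear layers computes the distance below:

- `ScalingLayer` shifts and scales the inputs.
- Each backbone layer's activations $\hat{y}^l$ are unit-normalised across channels.
- `NetLinLayer` (a learned 1x1 convolution with per-channel weights $w_{l,c}$) weights the squared differences.
- `spatial_average` takes the mean over the $H_l \times W_l$ positions.

```latex
d(x, x_0) = \sum_{l} \frac{1}{H_l W_l} \sum_{i=1}^{H_l} \sum_{j=1}^{W_l} \sum_{c} w_{l,c} \left( \hat{y}^{l}_{c,i,j} - \hat{y}^{l}_{0,c,i,j} \right)^2
```

The LPIPS paper writes the inner term as $\lVert w_l \odot (\hat{y}^l_{ij} - \hat{y}^l_{0,ij}) \rVert_2^2$. That is the same quantity, with the code's $w_{l,c}$ playing the role of the paper's $w_{l,c}^2$.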

FILE: metrics/lpips/pretrained_networks.py
  class squeezenet (line 5) | class squeezenet(torch.nn.Module):
    method __init__ (line 6) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 35) | def forward(self, X):
  class alexnet (line 56) | class alexnet(torch.nn.Module):
    method __init__ (line 57) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 80) | def forward(self, X):
  class vgg16 (line 96) | class vgg16(torch.nn.Module):
    method __init__ (line 97) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 120) | def forward(self, X):
  class resnet (line 138) | class resnet(torch.nn.Module):
    method __init__ (line 139) | def __init__(self, requires_grad=False, pretrained=True, num=18):
    method forward (line 162) | def forward(self, X):

FILE: metrics/pytorch_ssim/__init__.py
  function gaussian (line 7) | def gaussian(window_size, sigma):
  function create_window (line 11) | def create_window(window_size, channel):
  function create_window_3d (line 18) | def create_window_3d(window_size, channel=1):
  function _ssim (line 26) | def _ssim(img1, img2, window, window_size, channel, size_average = True):
  function ssim_matlab (line 49) | def ssim_matlab(img1, img2, window_size=11, window=None, size_average=Tr...
  class SSIM (line 105) | class SSIM(torch.nn.Module):
    method __init__ (line 106) | def __init__(self, window_size = 11, size_average = True):
    method forward (line 113) | def forward(self, img1, img2):
  function ssim (line 131) | def ssim(img1, img2, window_size = 11, size_average = True):
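
SSIM compares local means, variances and covariance:

`SSIM = ((2 mu_x mu_y + C1)(2 sigma_xy + C2)) / ((mu_x^2 + mu_y^2 + C1)(sigma_x^2 + sigma_y^2 + C2))`

Here `C1 = (0.01 L)^2` and `C2 = (0.03 L)^2` for dynamic range `L`. Implementations like this one evaluate the formula under a normalised Gaussian window (`gaussian(window_size, sigma)`, commonly 11 taps with sigma 1.5) and average the resulting map. The sketch builds such a window. To stay short, it evaluates SSIM once over the whole signal instead of per window. The names are hypothetical:

```python
import math


def gaussian_window_sketch(size=11, sigma=1.5):
    """1D Gaussian kernel centred on size // 2, normalised to sum to 1."""
    g = [math.exp(-((x - size // 2) ** 2) / (2 * sigma ** 2)) for x in range(size)]
    total = sum(g)
    return [v / total for v in g]


def global_ssim_sketch(x, y, data_range=1.0):
    """SSIM with a single window covering every sample."""
    c1, c2 = (0.01 * data_range) ** 2, (0.03 * data_range) ** 2
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    vx = sum((a - mx) ** 2 for a in x) / n
    vy = sum((b - my) ** 2 for b in y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y)) / n
    return ((2 * mx * my + c1) * (2 * cov + c2)) / ((mx ** 2 + my ** 2 + c1) * (vx + vy + c2))
```

The 2D window used for images is the outer product of this 1D kernel with itself.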

FILE: utility.py
  function read_frame_yuv2rgb (line 9) | def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt...
  function CharbonnierFunc (line 65) | def CharbonnierFunc(data, epsilon=0.001):
  function moduleNormalize (line 69) | def moduleNormalize(frame):
  function gaussian_kernel (line 73) | def gaussian_kernel(sz, sigma):
  function quantize (line 81) | def quantize(imTensor):
  function tensor2rgb (line 85) | def tensor2rgb(tensor):
  function calc_psnr (line 95) | def calc_psnr(gt, out, *args):
  function calc_ssim (line 104) | def calc_ssim(gt, out, *args):
  function calc_lpips (line 108) | def calc_lpips(gt, out, *args):
  function calc_flolpips (line 114) | def calc_flolpips(gt_list, out_list, inputs_list):
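
`CharbonnierFunc(data, epsilon=0.001)` is presumably the Charbonnier penalty: the mean of `sqrt(data**2 + epsilon**2)`. It is a smooth stand-in for the L1 loss that stays differentiable at zero. Below is a sketch on a list of residuals; `charbonnier_sketch` is hypothetical:

```python
import math


def charbonnier_sketch(residuals, epsilon=0.001):
    """Mean Charbonnier penalty sqrt(r^2 + eps^2) over a list of residuals."""
    return sum(math.sqrt(r * r + epsilon * epsilon) for r in residuals) / len(residuals)
```

For residuals much larger than `epsilon` it behaves like L1: `[3.0, -4.0]` with `epsilon=0` gives 3.5. At zero residual it bottoms out at `epsilon` rather than 0.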

Condensed preview: 47 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 101,
    "preview": "*.pth\n*.ckpt\n*__pycache__*\n*.pyc\n*egg*\n*src/*\n*.ipynb\nlogs/*\n*delete*\neval_results*\n*.idea*\n*.pytorch"
  },
  {
    "path": "LICENSE",
    "chars": 1068,
    "preview": "MIT License\n\nCopyright (c) 2023 danielism97\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
  },
  {
    "path": "README.md",
    "chars": 5798,
    "preview": "# LDMVFI: Video Frame Interpolation with Latent Diffusion Models\n\n[**Duolikun Danier**](https://danier97.github.io/), [*"
  },
  {
    "path": "configs/autoencoder/vqflow-f32.yaml",
    "chars": 1472,
    "preview": "model:\n  base_learning_rate: 1.0e-5\n  target: ldm.models.autoencoder.VQFlowNet\n  params:\n    monitor: val/total_loss\n   "
  },
  {
    "path": "configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml",
    "chars": 2659,
    "preview": "model:\n  base_learning_rate: 1.0e-06\n  target: ldm.models.diffusion.ddpm.LatentDiffusionVFI\n  params:\n    linear_start: "
  },
  {
    "path": "cupy_module/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cupy_module/dsepconv.py",
    "chars": 32355,
    "preview": "import torch\n\nimport cupy\nimport re\n\n\nclass Stream:\n    ptr = torch.cuda.current_stream().cuda_stream\n\n\n# end\n\nkernel_DS"
  },
  {
    "path": "environment.yaml",
    "chars": 1216,
    "preview": "name: ldmvfi\nchannels:\n  - pytorch\n  - defaults\n  - conda-forge\ndependencies:\n  - python=3.9.13\n  - pytorch=1.11.0\n  - t"
  },
  {
    "path": "evaluate.py",
    "chars": 2387,
    "preview": "import argparse\nimport os\nimport torch\nfrom functools import partial\nfrom omegaconf import OmegaConf\nfrom main import in"
  },
  {
    "path": "evaluate_vqm.py",
    "chars": 1386,
    "preview": "import argparse\nimport os\nfrom ldm.data import testsets_vqm\n\n\nparser = argparse.ArgumentParser(description='Frame Interp"
  },
  {
    "path": "interpolate_yuv.py",
    "chars": 5038,
    "preview": "import argparse\nimport torch\nimport torchvision.transforms.functional as TF\nimport os\nfrom PIL import Image\nfrom tqdm im"
  },
  {
    "path": "ldm/data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ldm/data/bvi_vimeo.py",
    "chars": 9439,
    "preview": "import numpy as np\nimport random\nfrom os import listdir\nfrom os.path import join, isdir, split, getsize\nfrom torch.utils"
  },
  {
    "path": "ldm/data/testsets.py",
    "chars": 20186,
    "preview": "import glob\nfrom typing import List\nfrom PIL import Image\nimport torch\nfrom torchvision import transforms\nimport torchvi"
  },
  {
    "path": "ldm/data/testsets_vqm.py",
    "chars": 16418,
    "preview": "import glob\nfrom typing import List\nfrom PIL import Image\nimport torch\nfrom torchvision import transforms\nimport torchvi"
  },
  {
    "path": "ldm/data/vfitransforms.py",
    "chars": 674,
    "preview": "import random\nimport torch\nimport torchvision\nimport torchvision.transforms.functional as TF\n\n\ndef rand_crop(*args, sz):"
  },
  {
    "path": "ldm/lr_scheduler.py",
    "chars": 3882,
    "preview": "import numpy as np\n\n\nclass LambdaWarmUpCosineScheduler:\n    \"\"\"\n    note: use with a base_lr of 1.0\n    \"\"\"\n    def __in"
  },
  {
    "path": "ldm/models/autoencoder.py",
    "chars": 15340,
    "preview": "import torch\nimport pytorch_lightning as pl\nimport torch.nn.functional as F\nfrom torch.optim.lr_scheduler import LambdaL"
  },
  {
    "path": "ldm/models/diffusion/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ldm/models/diffusion/ddim.py",
    "chars": 11063,
    "preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom functools import partial\n\nfrom ldm.modu"
  },
  {
    "path": "ldm/models/diffusion/ddpm.py",
    "chars": 74715,
    "preview": "\"\"\"\nwild mixture of\nhttps://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e316"
  },
  {
    "path": "ldm/modules/attention.py",
    "chars": 14950,
    "preview": "from inspect import isfunction\nimport math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum\nfro"
  },
  {
    "path": "ldm/modules/diffusionmodules/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ldm/modules/diffusionmodules/model.py",
    "chars": 23522,
    "preview": "# pytorch_diffusion + derived encoder decoder\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\n\nfrom ld"
  },
  {
    "path": "ldm/modules/diffusionmodules/openaimodel.py",
    "chars": 36432,
    "preview": "from abc import abstractmethod\nfrom functools import partial\nimport math\nfrom typing import Iterable\n\nimport numpy as np"
  },
  {
    "path": "ldm/modules/diffusionmodules/util.py",
    "chars": 9561,
    "preview": "# adopted from\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\n# and\n#"
  },
  {
    "path": "ldm/modules/ema.py",
    "chars": 2982,
    "preview": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n    def __init__(self, model, decay=0.9999, use_num_upates="
  },
  {
    "path": "ldm/modules/losses/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "ldm/modules/losses/vqperceptual.py",
    "chars": 7838,
    "preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom einops import repeat\n\nfrom taming.modules.discrim"
  },
  {
    "path": "ldm/modules/maxvit.py",
    "chars": 11799,
    "preview": "import torch\nfrom torch import nn, einsum\nimport torch.nn.functional\nfrom einops import rearrange, repeat\nfrom einops.la"
  },
  {
    "path": "ldm/util.py",
    "chars": 5857,
    "preview": "import importlib\n\nimport torch\nimport numpy as np\nfrom collections import abc\nfrom einops import rearrange\nfrom functool"
  },
  {
    "path": "main.py",
    "chars": 27803,
    "preview": "import argparse, os, sys, datetime, glob\nimport numpy as np\nimport time\nimport torch\nimport torchvision\nimport pytorch_l"
  },
  {
    "path": "metrics/flolpips/.gitignore",
    "chars": 30,
    "preview": "*__pycache__*\n*.ipynb\n*delete*"
  },
  {
    "path": "metrics/flolpips/LICENSE",
    "chars": 1068,
    "preview": "MIT License\n\nCopyright (c) 2022 danielism97\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
  },
  {
    "path": "metrics/flolpips/README.md",
    "chars": 1145,
    "preview": "# FloLPIPS: A bespoke video quality metric for frame interpoation\n\n### Duolikun Danier, Fan Zhang, David Bull\n\n\n[Project"
  },
  {
    "path": "metrics/flolpips/__init__.py",
    "chars": 24,
    "preview": "from .flolpips import *\n"
  },
  {
    "path": "metrics/flolpips/correlation/correlation.py",
    "chars": 13604,
    "preview": "#!/usr/bin/env python\n\nimport torch\n\nimport cupy\nimport re\n\nkernel_Correlation_rearrange = '''\n\textern \"C\" __global__ vo"
  },
  {
    "path": "metrics/flolpips/flolpips.py",
    "chars": 15089,
    "preview": "\nfrom __future__ import absolute_import\nimport os\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.autog"
  },
  {
    "path": "metrics/flolpips/pretrained_networks.py",
    "chars": 6507,
    "preview": "from collections import namedtuple\nimport torch\nfrom torchvision import models as tv\n\nclass squeezenet(torch.nn.Module):"
  },
  {
    "path": "metrics/flolpips/pwcnet.py",
    "chars": 16053,
    "preview": "#!/usr/bin/env python\n\nimport torch\n\nimport getopt\nimport math\nimport numpy\nimport os\nimport PIL\nimport PIL.Image\nimport"
  },
  {
    "path": "metrics/flolpips/utils.py",
    "chars": 3495,
    "preview": "import numpy as np\nimport cv2\nimport torch\n\n\ndef normalize_tensor(in_feat,eps=1e-10):\n    norm_factor = torch.sqrt(torch"
  },
  {
    "path": "metrics/lpips/__init__.py",
    "chars": 6170,
    "preview": "\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport nu"
  },
  {
    "path": "metrics/lpips/lpips.py",
    "chars": 9390,
    "preview": "\nfrom __future__ import absolute_import\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom torch.auto"
  },
  {
    "path": "metrics/lpips/pretrained_networks.py",
    "chars": 6507,
    "preview": "from collections import namedtuple\nimport torch\nfrom torchvision import models as tv\n\nclass squeezenet(torch.nn.Module):"
  },
  {
    "path": "metrics/pytorch_ssim/__init__.py",
    "chars": 4844,
    "preview": "import torch\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nimport numpy as np\nfrom math import exp"
  },
  {
    "path": "setup.py",
    "chars": 233,
    "preview": "from setuptools import setup, find_packages\n\nsetup(\n    name='latent-diffusion',\n    version='0.0.1',\n    description=''"
  },
  {
    "path": "utility.py",
    "chars": 5326,
    "preview": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nimport cv2\nfrom metrics import pytorch_ssim, lpips, flol"
  }
]
