Repository: danier97/LDMVFI Branch: main Commit: eee2dc3566f2 Files: 47 Total size: 425.2 KB Directory structure: gitextract_grs9wzig/ ├── .gitignore ├── LICENSE ├── README.md ├── configs/ │ ├── autoencoder/ │ │ └── vqflow-f32.yaml │ └── ldm/ │ └── ldmvfi-vqflow-f32-c256-concat_max.yaml ├── cupy_module/ │ ├── __init__.py │ └── dsepconv.py ├── environment.yaml ├── evaluate.py ├── evaluate_vqm.py ├── interpolate_yuv.py ├── ldm/ │ ├── data/ │ │ ├── __init__.py │ │ ├── bvi_vimeo.py │ │ ├── testsets.py │ │ ├── testsets_vqm.py │ │ └── vfitransforms.py │ ├── lr_scheduler.py │ ├── models/ │ │ ├── autoencoder.py │ │ └── diffusion/ │ │ ├── __init__.py │ │ ├── ddim.py │ │ └── ddpm.py │ ├── modules/ │ │ ├── attention.py │ │ ├── diffusionmodules/ │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ └── util.py │ │ ├── ema.py │ │ ├── losses/ │ │ │ ├── __init__.py │ │ │ └── vqperceptual.py │ │ └── maxvit.py │ └── util.py ├── main.py ├── metrics/ │ ├── flolpips/ │ │ ├── .gitignore │ │ ├── LICENSE │ │ ├── README.md │ │ ├── __init__.py │ │ ├── correlation/ │ │ │ └── correlation.py │ │ ├── flolpips.py │ │ ├── pretrained_networks.py │ │ ├── pwcnet.py │ │ └── utils.py │ ├── lpips/ │ │ ├── __init__.py │ │ ├── lpips.py │ │ └── pretrained_networks.py │ └── pytorch_ssim/ │ └── __init__.py ├── setup.py └── utility.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.pth *.ckpt *__pycache__* *.pyc *egg* *src/* *.ipynb logs/* *delete* eval_results* *.idea* *.pytorch ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2023 danielism97 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without 
restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # LDMVFI: Video Frame Interpolation with Latent Diffusion Models [**Duolikun Danier**](https://danier97.github.io/), [**Fan Zhang**](https://fan-aaron-zhang.github.io/), [**David Bull**](https://david-bull.github.io/) [Project](TODO) | [arXiv](https://arxiv.org/abs/2303.09508) | [Video](https://drive.google.com/file/d/1oL6j_l3b2QEqsL0iO7qSZrGUXJaTpRWN/view?usp=share_link) ![Demo gif](assets/ldmvfi.gif) ## Overview We observe that most existing learning-based VFI models are trained to minimise the L1/L2/VGG loss between their outputs and the ground-truth frames. However, it was shown in previous works that these metrics do not correlate well with the **perceptual quality** of VFI. On the other hand, generative models, especially diffusion models, are showing remarkable results in generating visual content with high perceptual quality. In this work, we leverage the high-fidelity image/video generation capabilities of **latent diffusion models** to perform generative VFI.


## Dependencies and Installation See [environment.yaml](./environment.yaml) for requirements on packages. Simple installation: ``` conda env create -f environment.yaml ``` ## Pre-trained Model The pre-trained model can be downloaded from [here](https://drive.google.com/file/d/1_Xx2fBYQT9O-6O3zjzX76O9XduGnCh_7/view?usp=share_link), and its corresponding config file is [this yaml](./configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml). ## Preparing datasets ### Training sets: [[Vimeo-90K]](http://toflow.csail.mit.edu/) | [[BVI-DVC quintuplets]](https://drive.google.com/file/d/1i_CoqiNrZ2AU8DKjU8aHM1jIaDGW0fE5/view?usp=sharing) ### Test sets: [[Middlebury]](https://vision.middlebury.edu/flow/data/) | [[UCF101]](https://sites.google.com/view/xiangyuxu/qvi_nips19) | [[DAVIS]](https://sites.google.com/view/xiangyuxu/qvi_nips19) | [[SNU-FILM]](https://myungsub.github.io/CAIN/) To make use of the [evaluate.py](evaluate.py) and the files in [ldm/data/](./ldm/data/), the dataset folder names should be lower-case and structured as follows. ``` └──── / ├──── middlebury_others/ | ├──── input/ | | ├──── Beanbags/ | | ├──── ... | | └──── Walking/ | └──── gt/ | ├──── Beanbags/ | ├──── ... | └──── Walking/ ├──── ucf101/ | ├──── 0/ | ├──── ... | └──── 99/ ├──── davis90/ | ├──── bear/ | ├──── ... | └──── walking/ ├──── snufilm/ | ├──── test-easy.txt | ├──── ... | └──── data/SNU-FILM/test/... ├──── bvidvc/quintuplets | ├──── 00000/ | ├──── ... | └──── 17599/ └──── vimeo_septuplet/ ├──── sequences/ ├──── sep_testlist.txt └──── sep_trainlist.txt ``` ## Evaluation To evaluate LDMVFI (with DDIM sampler), for example, on the Middlebury dataset, using PSNR/SSIM/LPIPS, run the following command. 
```
python evaluate.py \
--config configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml \
--ckpt \
--dataset Middlebury_others \
--metrics PSNR SSIM LPIPS \
--data_dir \
--out_dir eval_results/ldmvfi-vqflow-f32-c256-concat_max/ \
--use_ddim
```
This will create the directory `eval_results/ldmvfi-vqflow-f32-c256-concat_max/Middlebury_others/` and store the interpolated frames, as well as a `results.txt` file, in that directory. For other test sets, replace `Middlebury_others` with the corresponding class name defined in [ldm/data/testsets.py](ldm/data/testsets.py) (e.g. `Ucf101_triplet`).

To evaluate the model on the perceptual video metric FloLPIPS, first evaluate the image metrics using the command above (so that the interpolated frames are saved in `eval_results/ldmvfi-vqflow-f32-c256-concat_max`), then run the following command.
```
python evaluate_vqm.py \
--exp ldmvfi-vqflow-f32-c256-concat_max \
--dataset Middlebury_others \
--metrics FloLPIPS \
--data_dir \
--out_dir eval_results/ldmvfi-vqflow-f32-c256-concat_max/
```
This will read the interpolated frames previously stored in `eval_results/ldmvfi-vqflow-f32-c256-concat_max/Middlebury_others/`, then write the evaluation results to `results_vqm.txt` in the same folder.

To interpolate a video (in .yuv format), use the following command.
```
python interpolate_yuv.py \
--net LDMVFI \
--config configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml \
--ckpt \
--input_yuv \
--size \
--out_fps \
--out_dir \
--use_ddim
```

## Training
LDMVFI is trained in two stages, where the VQ-FIGAN and the denoising U-Net are trained separately.
### VQ-FIGAN
```
python main.py --base configs/autoencoder/vqflow-f32.yaml -t --gpus 0,
```
### Denoising U-Net
```
python main.py --base configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml -t --gpus 0,
```
These will create a `logs/` folder, within which a corresponding directory is created for each experiment. The log files from training include checkpoints, images and TensorBoard logs.
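Since each experiment gets its own directory under `logs/`, it can be handy to locate the most recent checkpoint programmatically before resuming. The following is a minimal sketch, not part of the repository: the helper name `find_latest_checkpoint` is hypothetical, and it assumes only that checkpoints are saved as `.ckpt` files somewhere below `logs/` (the exact sub-directory layout is not assumed).

```python
from pathlib import Path
from typing import Optional


def find_latest_checkpoint(log_root: str = "logs") -> Optional[Path]:
    """Return the most recently modified .ckpt file under log_root, or None.

    Hypothetical helper: it makes no assumption about how experiment
    directories are named, and simply searches the tree recursively.
    """
    root = Path(log_root)
    if not root.is_dir():
        return None
    checkpoints = [p for p in root.rglob("*.ckpt") if p.is_file()]
    if not checkpoints:
        return None
    # Pick the file with the newest modification time.
    return max(checkpoints, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    ckpt = find_latest_checkpoint("logs")
    if ckpt is None:
        print("No checkpoint found under logs/")
    else:
        # This path can then be passed to the --resume argument of main.py.
        print(ckpt)
```

Modification time is used as the ordering criterion because it does not depend on any file-naming scheme; if several experiments run concurrently, point `log_root` at the specific experiment directory instead.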
To resume from a checkpoint file, simply use the `--resume` argument in [main.py](main.py) to specify the checkpoint. ## Citation ``` @article{danier2023ldmvfi, title={LDMVFI: Video Frame Interpolation with Latent Diffusion Models}, author={Danier, Duolikun and Zhang, Fan and Bull, David}, journal={arXiv preprint arXiv:2303.09508}, year={2023} } ``` ## Acknowledgement Our code is adapted from the original [latent-diffusion](https://github.com/CompVis/latent-diffusion) repository. We thank the authors for sharing their code. ================================================ FILE: configs/autoencoder/vqflow-f32.yaml ================================================ model: base_learning_rate: 1.0e-5 target: ldm.models.autoencoder.VQFlowNet params: monitor: val/total_loss embed_dim: 3 n_embed: 8192 ddconfig: double_z: False z_channels: 3 resolution: 256 in_channels: 3 out_ch: 3 ch: 64 ch_mult: [1,2,2,2,4] # f = 2 ^ len(ch_mult) num_res_blocks: 1 cond_type: max_cross_attn attn_type: max attn_resolutions: [] dropout: 0.0 lossconfig: target: ldm.modules.losses.vqperceptual.VQLPIPSWithDiscriminator params: disc_conditional: False disc_in_channels: 3 disc_start: 10000 disc_weight: 0.8 codebook_weight: 1.0 data: target: main.DataModuleFromConfig params: batch_size: 10 num_workers: 0 wrap: false train: target: ldm.data.bvi_vimeo.BVI_Vimeo_triplet params: db_dir: C:/data_tmp/ crop_sz: [256,256] iter: True validation: target: ldm.data.bvi_vimeo.Vimeo90k_triplet params: db_dir: C:/data_tmp/vimeo_septuplet/ train: False crop_sz: [256,256] augment_s: False augment_t: False lightning: callbacks: image_logger: target: main.ImageLogger params: batch_frequency: 8000 val_batch_frequency: 800 max_images: 8 increase_log_steps: False log_images_kwargs: {'N': 1} trainer: benchmark: True max_epochs: -1 ================================================ FILE: configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml ================================================ model: base_learning_rate: 1.0e-06 
target: ldm.models.diffusion.ddpm.LatentDiffusionVFI params: linear_start: 0.0015 linear_end: 0.0195 num_timesteps_cond: 1 log_every_t: 200 timesteps: 1000 first_stage_key: image cond_stage_key: past_future_frames image_size: 8 channels: 3 cond_stage_trainable: False concat_mode: True monitor: val/loss_simple_ema unet_config: target: ldm.modules.diffusionmodules.openaimodel.UNetModel params: image_size: 8 # img size of latent, used during training, determines some model params, so don't change for inference in_channels: 9 out_channels: 3 model_channels: 256 attention_resolutions: #note: this isn't actually the resolution but # the downsampling factor, i.e. this corresponds to # attention on spatial resolution 8,16,32, as the # spatial resolution of the latents is 32 for f8 - 4 - 2 - 1 num_res_blocks: 2 channel_mult: - 1 - 2 - 4 num_head_channels: 32 use_max_self_attn: True # replace all full self-attention with MaxViT first_stage_config: target: ldm.models.autoencoder.VQFlowNetInterface params: ckpt_path: null # must specify pre-trained autoencoding model ckpt to train the denoising UNet embed_dim: 3 n_embed: 8192 ddconfig: double_z: False z_channels: 3 resolution: 256 in_channels: 3 out_ch: 3 ch: 64 ch_mult: [1,2,2,2,4] # f = 2 ^ len(ch_mult) num_res_blocks: 1 cond_type: max_cross_attn attn_type: max attn_resolutions: [ ] dropout: 0.0 lossconfig: target: torch.nn.Identity cond_stage_config: __is_first_stage__ data: target: main.DataModuleFromConfig params: batch_size: 64 num_workers: 0 wrap: false train: target: ldm.data.bvi_vimeo.BVI_Vimeo_triplet params: db_dir: C:/data_tmp/ crop_sz: [256,256] iter: True validation: target: ldm.data.bvi_vimeo.Vimeo90k_triplet params: db_dir: C:/data_tmp/vimeo_septuplet/ train: False crop_sz: [256,256] augment_s: False augment_t: False lightning: callbacks: image_logger: target: main.ImageLogger params: batch_frequency: 1250 val_batch_frequency: 125 max_images: 8 increase_log_steps: False log_images_kwargs: {'N': 1} trainer: 
benchmark: True max_epochs: -1 ================================================ FILE: cupy_module/__init__.py ================================================ ================================================ FILE: cupy_module/dsepconv.py ================================================ import torch import cupy import re class Stream: ptr = torch.cuda.current_stream().cuda_stream # end kernel_DSepconv_updateOutput = ''' extern "C" __global__ void kernel_DSepconv_updateOutput( const int n, const float* input, const float* vertical, const float* horizontal, const float* offset_x, const float* offset_y, const float* mask, float* output ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { float dblOutput = 0.0; const int intSample = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); const int intDepth = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output); const int intX = ( intIndex ) % SIZE_3(output); for (int intFilterY = 0; intFilterY < SIZE_1(vertical); intFilterY += 1) { for (int intFilterX = 0; intFilterX < SIZE_1(horizontal); intFilterX += 1) { float delta_x = OFFSET_4(offset_y, intSample, intFilterY*SIZE_1(vertical) + intFilterX, intY, intX); float delta_y = OFFSET_4(offset_x, intSample, intFilterY*SIZE_1(vertical) + intFilterX, intY, intX); float position_x = delta_x + intX + intFilterX - (SIZE_1(horizontal) - 1) / 2 + 1; float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1; if (position_x < 0) position_x = 0; if (position_x > SIZE_3(input) - 1) position_x = SIZE_3(input) - 1; if (position_y < 0) position_y = 0; if (position_y > SIZE_2(input) - 1) position_y = SIZE_2(input) - 1; int left = floor(delta_x + intX + intFilterX - (SIZE_1(horizontal) - 1) / 2 + 1); int right = left + 1; if (left < 0) left = 0; if (left > SIZE_3(input) - 1) left = SIZE_3(input) - 1; if 
(right < 0) right = 0; if (right > SIZE_3(input) - 1) right = SIZE_3(input) - 1; int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1); int bottom = top + 1; if (top < 0) top = 0; if (top > SIZE_2(input) - 1) top = SIZE_2(input) - 1; if (bottom < 0) bottom = 0; if (bottom > SIZE_2(input) - 1) bottom = SIZE_2(input) - 1; float floatValue = VALUE_4(input, intSample, intDepth, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, intDepth, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, intDepth, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + VALUE_4(input, intSample, intDepth, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y)); dblOutput += floatValue * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(horizontal, intSample, intFilterX, intY, intX) * VALUE_4(mask, intSample, SIZE_1(vertical)*intFilterY + intFilterX, intY, intX); } } output[intIndex] = dblOutput; } } ''' kernel_DSepconv_updateGradVertical = ''' extern "C" __global__ void kernel_DSepconv_updateGradVertical( const int n, const float* gradLoss, const float* input, const float* horizontal, const float* offset_x, const float* offset_y, const float* mask, float* gradVertical ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { float floatOutput = 0.0; const int intSample = ( intIndex / SIZE_3(gradVertical) / SIZE_2(gradVertical) / SIZE_1(gradVertical) ) % SIZE_0(gradVertical); const int intFilterY = ( intIndex / SIZE_3(gradVertical) / SIZE_2(gradVertical) ) % SIZE_1(gradVertical); const int intY = ( intIndex / SIZE_3(gradVertical) ) % SIZE_2(gradVertical); const int intX = ( intIndex ) % SIZE_3(gradVertical); for (int intFilterX = 0; intFilterX < SIZE_1(horizontal); intFilterX += 1){ int intDepth = intFilterY * SIZE_1(horizontal) + intFilterX; float delta_x = 
OFFSET_4(offset_y, intSample, intDepth, intY, intX); float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX); float position_x = delta_x + intX + intFilterX - (SIZE_1(horizontal) - 1) / 2 + 1; float position_y = delta_y + intY + intFilterY - (SIZE_1(horizontal) - 1) / 2 + 1; if (position_x < 0) position_x = 0; if (position_x > SIZE_3(input) - 1) position_x = SIZE_3(input) - 1; if (position_y < 0) position_y = 0; if (position_y > SIZE_2(input) - 1) position_y = SIZE_2(input) - 1; int left = floor(delta_x + intX + intFilterX - (SIZE_1(horizontal) - 1) / 2 + 1); int right = left + 1; if (left < 0) left = 0; if (left > SIZE_3(input) - 1) left = SIZE_3(input) - 1; if (right < 0) right = 0; if (right > SIZE_3(input) - 1) right = SIZE_3(input) - 1; int top = floor(delta_y + intY + intFilterY - (SIZE_1(horizontal) - 1) / 2 + 1); int bottom = top + 1; if (top < 0) top = 0; if (top > SIZE_2(input) - 1) top = SIZE_2(input) - 1; if (bottom < 0) bottom = 0; if (bottom > SIZE_2(input) - 1) bottom = SIZE_2(input) - 1; float floatSampled0 = VALUE_4(input, intSample, 0, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, 0, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, 0, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + VALUE_4(input, intSample, 0, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y)); float floatSampled1 = VALUE_4(input, intSample, 1, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, 1, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, 1, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + VALUE_4(input, intSample, 1, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y)); float floatSampled2 = VALUE_4(input, intSample, 2, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + 
VALUE_4(input, intSample, 2, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, 2, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + VALUE_4(input, intSample, 2, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y)); floatOutput += VALUE_4(gradLoss, intSample, 0, intY, intX) * floatSampled0 * VALUE_4(horizontal, intSample, intFilterX, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX) + VALUE_4(gradLoss, intSample, 1, intY, intX) * floatSampled1 * VALUE_4(horizontal, intSample, intFilterX, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX) + VALUE_4(gradLoss, intSample, 2, intY, intX) * floatSampled2 * VALUE_4(horizontal, intSample, intFilterX, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX); } gradVertical[intIndex] = floatOutput; } } ''' kernel_DSepconv_updateGradHorizontal = ''' extern "C" __global__ void kernel_DSepconv_updateGradHorizontal( const int n, const float* gradLoss, const float* input, const float* vertical, const float* offset_x, const float* offset_y, const float* mask, float* gradHorizontal ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { float floatOutput = 0.0; const int intSample = ( intIndex / SIZE_3(gradHorizontal) / SIZE_2(gradHorizontal) / SIZE_1(gradHorizontal) ) % SIZE_0(gradHorizontal); const int intFilterX = ( intIndex / SIZE_3(gradHorizontal) / SIZE_2(gradHorizontal) ) % SIZE_1(gradHorizontal); const int intY = ( intIndex / SIZE_3(gradHorizontal) ) % SIZE_2(gradHorizontal); const int intX = ( intIndex ) % SIZE_3(gradHorizontal); for (int intFilterY = 0; intFilterY < SIZE_1(vertical); intFilterY += 1){ int intDepth = intFilterY * SIZE_1(vertical) + intFilterX; float delta_x = OFFSET_4(offset_y, intSample, intDepth, intY, intX); float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX); float position_x = delta_x + intX + intFilterX - 
(SIZE_1(vertical) - 1) / 2 + 1; float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1; if (position_x < 0) position_x = 0; if (position_x > SIZE_3(input) - 1) position_x = SIZE_3(input) - 1; if (position_y < 0) position_y = 0; if (position_y > SIZE_2(input) - 1) position_y = SIZE_2(input) - 1; int left = floor(delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1); int right = left + 1; if (left < 0) left = 0; if (left > SIZE_3(input) - 1) left = SIZE_3(input) - 1; if (right < 0) right = 0; if (right > SIZE_3(input) - 1) right = SIZE_3(input) - 1; int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1); int bottom = top + 1; if (top < 0) top = 0; if (top > SIZE_2(input) - 1) top = SIZE_2(input) - 1; if (bottom < 0) bottom = 0; if (bottom > SIZE_2(input) - 1) bottom = SIZE_2(input) - 1; float floatSampled0 = VALUE_4(input, intSample, 0, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, 0, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, 0, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + VALUE_4(input, intSample, 0, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y)); float floatSampled1 = VALUE_4(input, intSample, 1, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, 1, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, 1, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + VALUE_4(input, intSample, 1, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y)); float floatSampled2 = VALUE_4(input, intSample, 2, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, 2, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, 2, bottom, left) * (1 + (left - position_x)) * (1 - 
(bottom - position_y)) + VALUE_4(input, intSample, 2, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y)); floatOutput += VALUE_4(gradLoss, intSample, 0, intY, intX) * floatSampled0 * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX) + VALUE_4(gradLoss, intSample, 1, intY, intX) * floatSampled1 * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX) + VALUE_4(gradLoss, intSample, 2, intY, intX) * floatSampled2 * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX); } gradHorizontal[intIndex] = floatOutput; } } ''' kernel_DSepconv_updateGradMask = ''' extern "C" __global__ void kernel_DSepconv_updateGradMask( const int n, const float* gradLoss, const float* input, const float* vertical, const float* horizontal, const float* offset_x, const float* offset_y, float* gradMask ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { float floatOutput = 0.0; const int intSample = ( intIndex / SIZE_3(gradMask) / SIZE_2(gradMask) / SIZE_1(gradMask) ) % SIZE_0(gradMask); const int intDepth = ( intIndex / SIZE_3(gradMask) / SIZE_2(gradMask) ) % SIZE_1(gradMask); const int intY = ( intIndex / SIZE_3(gradMask) ) % SIZE_2(gradMask); const int intX = ( intIndex ) % SIZE_3(gradMask); int intFilterY = intDepth / SIZE_1(vertical); int intFilterX = intDepth % SIZE_1(vertical); float delta_x = OFFSET_4(offset_y, intSample, intDepth, intY, intX); float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX); float position_x = delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1; float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1; if (position_x < 0) position_x = 0; if (position_x > SIZE_3(input) - 1) position_x = SIZE_3(input) - 1; if (position_y < 0) position_y = 0; if (position_y > SIZE_2(input) - 1) position_y = 
SIZE_2(input) - 1; int left = floor(delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1); int right = left + 1; if (left < 0) left = 0; if (left > SIZE_3(input) - 1) left = SIZE_3(input) - 1; if (right < 0) right = 0; if (right > SIZE_3(input) - 1) right = SIZE_3(input) - 1; int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1); int bottom = top + 1; if (top < 0) top = 0; if (top > SIZE_2(input) - 1) top = SIZE_2(input) - 1; if (bottom < 0) bottom = 0; if (bottom > SIZE_2(input) - 1) bottom = SIZE_2(input) - 1; for (int intChannel = 0; intChannel < 3; intChannel++){ floatOutput += VALUE_4(gradLoss, intSample, intChannel, intY, intX) * ( VALUE_4(input, intSample, intChannel, top, left) * (1 + (left - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, intChannel, top, right) * (1 - (right - position_x)) * (1 + (top - position_y)) + VALUE_4(input, intSample, intChannel, bottom, left) * (1 + (left - position_x)) * (1 - (bottom - position_y)) + VALUE_4(input, intSample, intChannel, bottom, right) * (1 - (right - position_x)) * (1 - (bottom - position_y)) ) * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(horizontal, intSample, intFilterX, intY, intX); } gradMask[intIndex] = floatOutput; } } ''' kernel_DSepconv_updateGradOffsetX = ''' extern "C" __global__ void kernel_DSepconv_updateGradOffsetX( const int n, const float* gradLoss, const float* input, const float* vertical, const float* horizontal, const float* offset_x, const float* offset_y, const float* mask, float* gradOffsetX ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { float floatOutput = 0.0; const int intSample = ( intIndex / SIZE_3(gradOffsetX) / SIZE_2(gradOffsetX) / SIZE_1(gradOffsetX) ) % SIZE_0(gradOffsetX); const int intDepth = ( intIndex / SIZE_3(gradOffsetX) / SIZE_2(gradOffsetX) ) % SIZE_1(gradOffsetX); const int intY = ( intIndex / SIZE_3(gradOffsetX) ) % 
SIZE_2(gradOffsetX); const int intX = ( intIndex ) % SIZE_3(gradOffsetX); int intFilterY = intDepth / SIZE_1(vertical); int intFilterX = intDepth % SIZE_1(vertical); float delta_x = OFFSET_4(offset_y, intSample, intDepth, intY, intX); float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX); float position_x = delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1; float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1; if (position_x < 0) position_x = 0; if (position_x > SIZE_3(input) - 1) position_x = SIZE_3(input) - 1; if (position_y < 0) position_y = 0; if (position_y > SIZE_2(input) - 1) position_y = SIZE_2(input) - 1; int left = floor(delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1); int right = left + 1; if (left < 0) left = 0; if (left > SIZE_3(input) - 1) left = SIZE_3(input) - 1; if (right < 0) right = 0; if (right > SIZE_3(input) - 1) right = SIZE_3(input) - 1; int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1); int bottom = top + 1; if (top < 0) top = 0; if (top > SIZE_2(input) - 1) top = SIZE_2(input) - 1; if (bottom < 0) bottom = 0; if (bottom > SIZE_2(input) - 1) bottom = SIZE_2(input) - 1; for (int intChannel = 0; intChannel < 3; intChannel++){ floatOutput += VALUE_4(gradLoss, intSample, intChannel, intY, intX) * ( - VALUE_4(input, intSample, intChannel, top, left) * (1 + (left - position_x)) - VALUE_4(input, intSample, intChannel, top, right) * (1 - (right - position_x)) + VALUE_4(input, intSample, intChannel, bottom, left) * (1 + (left - position_x)) + VALUE_4(input, intSample, intChannel, bottom, right) * (1 - (right - position_x)) ) * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(horizontal, intSample, intFilterX, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX); } gradOffsetX[intIndex] = floatOutput; } } ''' kernel_DSepconv_updateGradOffsetY = ''' extern "C" __global__ void kernel_DSepconv_updateGradOffsetY( const int n, const float* 
gradLoss, const float* input, const float* vertical, const float* horizontal, const float* offset_x, const float* offset_y, const float* mask, float* gradOffsetY ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { float floatOutput = 0.0; const int intSample = ( intIndex / SIZE_3(gradOffsetX) / SIZE_2(gradOffsetX) / SIZE_1(gradOffsetX) ) % SIZE_0(gradOffsetX); const int intDepth = ( intIndex / SIZE_3(gradOffsetX) / SIZE_2(gradOffsetX) ) % SIZE_1(gradOffsetX); const int intY = ( intIndex / SIZE_3(gradOffsetX) ) % SIZE_2(gradOffsetX); const int intX = ( intIndex ) % SIZE_3(gradOffsetX); int intFilterY = intDepth / SIZE_1(vertical); int intFilterX = intDepth % SIZE_1(vertical); float delta_x = OFFSET_4(offset_y, intSample, intDepth, intY, intX); float delta_y = OFFSET_4(offset_x, intSample, intDepth, intY, intX); float position_x = delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1; float position_y = delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1; if (position_x < 0) position_x = 0; if (position_x > SIZE_3(input) - 1) position_x = SIZE_3(input) - 1; if (position_y < 0) position_y = 0; if (position_y > SIZE_2(input) - 1) position_y = SIZE_2(input) - 1; int left = floor(delta_x + intX + intFilterX - (SIZE_1(vertical) - 1) / 2 + 1); int right = left + 1; if (left < 0) left = 0; if (left > SIZE_3(input) - 1) left = SIZE_3(input) - 1; if (right < 0) right = 0; if (right > SIZE_3(input) - 1) right = SIZE_3(input) - 1; int top = floor(delta_y + intY + intFilterY - (SIZE_1(vertical) - 1) / 2 + 1); int bottom = top + 1; if (top < 0) top = 0; if (top > SIZE_2(input) - 1) top = SIZE_2(input) - 1; if (bottom < 0) bottom = 0; if (bottom > SIZE_2(input) - 1) bottom = SIZE_2(input) - 1; for (int intChannel = 0; intChannel < 3; intChannel++){ floatOutput += VALUE_4(gradLoss, intSample, intChannel, intY, intX) * ( - VALUE_4(input, intSample, intChannel, top, left) * (1 + (top - position_y)) + 
VALUE_4(input, intSample, intChannel, top, right) * (1 + (top - position_y)) - VALUE_4(input, intSample, intChannel, bottom, left) * (1 - (bottom - position_y)) + VALUE_4(input, intSample, intChannel, bottom, right) * (1 - (bottom - position_y)) ) * VALUE_4(vertical, intSample, intFilterY, intY, intX) * VALUE_4(horizontal, intSample, intFilterX, intY, intX) * VALUE_4(mask, intSample, intDepth, intY, intX); } gradOffsetY[intIndex] = floatOutput; } } ''' def cupy_kernel(strFunction, objectVariables): strKernel = globals()[strFunction] while True: objectMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) if objectMatch is None: break # end intArg = int(objectMatch.group(2)) strTensor = objectMatch.group(4) intSizes = objectVariables[strTensor].size() strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg])) # end while True: objectMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) if objectMatch is None: break # end intArgs = int(objectMatch.group(2)) strArgs = objectMatch.group(4).split(',') strTensor = strArgs[0] intStrides = objectVariables[strTensor].stride() strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( intStrides[intArg]) + ')' for intArg in range(intArgs)] strKernel = strKernel.replace(objectMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') # end while True: objectMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel) if objectMatch is None: break # end intArgs = int(objectMatch.group(2)) strArgs = objectMatch.group(4).split(',') strTensor = strArgs[0] intStrides = objectVariables[strTensor].stride() strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( intStrides[intArg]) + ')' for intArg in range(intArgs)] strKernel = strKernel.replace(objectMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') # end return strKernel # end @cupy.memoize(for_each_device=True) def cupy_launch(strFunction, 
strKernel): # return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) return cupy.RawKernel(strKernel, strFunction) # end class _FunctionDSepconv(torch.autograd.Function): @staticmethod def forward(self, input, vertical, horizontal, offset_x, offset_y, mask): self.save_for_backward(input, vertical, horizontal, offset_x, offset_y, mask) intSample = input.size(0) intInputDepth = input.size(1) intInputHeight = input.size(2) intInputWidth = input.size(3) intFilterSize = min(vertical.size(1), horizontal.size(1)) intOutputHeight = min(vertical.size(2), horizontal.size(2)) intOutputWidth = min(vertical.size(3), horizontal.size(3)) assert (intInputHeight == intOutputHeight + intFilterSize - 1) assert (intInputWidth == intOutputWidth + intFilterSize - 1) assert (input.is_contiguous() == True) assert (vertical.is_contiguous() == True) assert (horizontal.is_contiguous() == True) assert (offset_x.is_contiguous() == True) assert (offset_y.is_contiguous() == True) assert (mask.is_contiguous() == True) output = input.new_zeros([intSample, intInputDepth, intOutputHeight, intOutputWidth]) if input.is_cuda == True: n = output.nelement() cupy_launch('kernel_DSepconv_updateOutput', cupy_kernel('kernel_DSepconv_updateOutput', { 'input': input, 'vertical': vertical, 'horizontal': horizontal, 'offset_x': offset_x, 'offset_y': offset_y, 'mask': mask, 'output': output }))( grid=tuple([int((n + 512 - 1) / 512), 1, 1]), block=tuple([512, 1, 1]), args=[n, input.data_ptr(), vertical.data_ptr(), horizontal.data_ptr(), offset_x.data_ptr(), offset_y.data_ptr(), mask.data_ptr(), output.data_ptr()], stream=Stream ) elif input.is_cuda == False: raise NotImplementedError() # end return output # end @staticmethod def backward(self, gradOutput): input, vertical, horizontal, offset_x, offset_y, mask = self.saved_tensors intSample = input.size(0) intInputDepth = input.size(1) intInputHeight = input.size(2) intInputWidth = input.size(3) intFilterSize = min(vertical.size(1), 
horizontal.size(1)) intOutputHeight = min(vertical.size(2), horizontal.size(2)) intOutputWidth = min(vertical.size(3), horizontal.size(3)) assert (intInputHeight == intOutputHeight + intFilterSize - 1) assert (intInputWidth == intOutputWidth + intFilterSize - 1) assert (gradOutput.is_contiguous() == True) gradInput = input.new_zeros([intSample, intInputDepth, intInputHeight, intInputWidth]) if \ self.needs_input_grad[0] == True else None gradVertical = input.new_zeros([intSample, intFilterSize, intOutputHeight, intOutputWidth]) if \ self.needs_input_grad[1] == True else None gradHorizontal = input.new_zeros([intSample, intFilterSize, intOutputHeight, intOutputWidth]) if \ self.needs_input_grad[2] == True else None gradOffsetX = input.new_zeros([intSample, intFilterSize * intFilterSize, intOutputHeight, intOutputWidth]) if \ self.needs_input_grad[3] == True else None gradOffsetY = input.new_zeros([intSample, intFilterSize * intFilterSize, intOutputHeight, intOutputWidth]) if \ self.needs_input_grad[4] == True else None gradMask = input.new_zeros([intSample, intFilterSize * intFilterSize, intOutputHeight, intOutputWidth]) if \ self.needs_input_grad[5] == True else None if input.is_cuda == True: nv = gradVertical.nelement() cupy_launch('kernel_DSepconv_updateGradVertical', cupy_kernel('kernel_DSepconv_updateGradVertical', { 'gradLoss': gradOutput, 'input': input, 'horizontal': horizontal, 'offset_x': offset_x, 'offset_y': offset_y, 'mask': mask, 'gradVertical': gradVertical }))( grid=tuple([int((nv + 512 - 1) / 512), 1, 1]), block=tuple([512, 1, 1]), args=[nv, gradOutput.data_ptr(), input.data_ptr(), horizontal.data_ptr(), offset_x.data_ptr(), offset_y.data_ptr(), mask.data_ptr(), gradVertical.data_ptr()], stream=Stream ) nh = gradHorizontal.nelement() cupy_launch('kernel_DSepconv_updateGradHorizontal', cupy_kernel('kernel_DSepconv_updateGradHorizontal', { 'gradLoss': gradOutput, 'input': input, 'vertical': vertical, 'offset_x': offset_x, 'offset_y': offset_y, 'mask': 
mask, 'gradHorizontal': gradHorizontal }))( grid=tuple([int((nh + 512 - 1) / 512), 1, 1]), block=tuple([512, 1, 1]), args=[nh, gradOutput.data_ptr(), input.data_ptr(), vertical.data_ptr(), offset_x.data_ptr(), offset_y.data_ptr(), mask.data_ptr(), gradHorizontal.data_ptr()], stream=Stream ) nx = gradOffsetX.nelement() cupy_launch('kernel_DSepconv_updateGradOffsetX', cupy_kernel('kernel_DSepconv_updateGradOffsetX', { 'gradLoss': gradOutput, 'input': input, 'vertical': vertical, 'horizontal': horizontal, 'offset_x': offset_x, 'offset_y': offset_y, 'mask': mask, 'gradOffsetX': gradOffsetX }))( grid=tuple([int((nx + 512 - 1) / 512), 1, 1]), block=tuple([512, 1, 1]), args=[nx, gradOutput.data_ptr(), input.data_ptr(), vertical.data_ptr(), horizontal.data_ptr(), offset_x.data_ptr(), offset_y.data_ptr(), mask.data_ptr(), gradOffsetX.data_ptr()], stream=Stream ) ny = gradOffsetY.nelement() cupy_launch('kernel_DSepconv_updateGradOffsetY', cupy_kernel('kernel_DSepconv_updateGradOffsetY', { 'gradLoss': gradOutput, 'input': input, 'vertical': vertical, 'horizontal': horizontal, 'offset_x': offset_x, 'offset_y': offset_y, 'mask': mask, 'gradOffsetX': gradOffsetY }))( grid=tuple([int((ny + 512 - 1) / 512), 1, 1]), block=tuple([512, 1, 1]), args=[ny, gradOutput.data_ptr(), input.data_ptr(), vertical.data_ptr(), horizontal.data_ptr(), offset_x.data_ptr(), offset_y.data_ptr(), mask.data_ptr(), gradOffsetY.data_ptr()], stream=Stream ) nm = gradMask.nelement() cupy_launch('kernel_DSepconv_updateGradMask', cupy_kernel('kernel_DSepconv_updateGradMask', { 'gradLoss': gradOutput, 'input': input, 'vertical': vertical, 'horizontal': horizontal, 'offset_x': offset_x, 'offset_y': offset_y, 'gradMask': gradMask }))( grid=tuple([int((nm + 512 - 1) / 512), 1, 1]), block=tuple([512, 1, 1]), args=[nm, gradOutput.data_ptr(), input.data_ptr(), vertical.data_ptr(), horizontal.data_ptr(), offset_x.data_ptr(), offset_y.data_ptr(), gradMask.data_ptr()], stream=Stream ) elif input.is_cuda == False: raise 
NotImplementedError() # end return gradInput, gradVertical, gradHorizontal, gradOffsetX, gradOffsetY, gradMask # end # end def FunctionDSepconv(tensorInput, tensorVertical, tensorHorizontal, tensorOffsetX, tensorOffsetY, tensorMask): return _FunctionDSepconv.apply(tensorInput, tensorVertical, tensorHorizontal, tensorOffsetX, tensorOffsetY, tensorMask) # end class ModuleDSepconv(torch.nn.Module): def __init__(self): super(ModuleDSepconv, self).__init__() # end def forward(self, tensorInput, tensorVertical, tensorHorizontal, tensorOffsetX, tensorOffsetY, tensorMask): return _FunctionDSepconv.apply(tensorInput, tensorVertical, tensorHorizontal, tensorOffsetX, tensorOffsetY, tensorMask) # end # end # float floatValue = VALUE_4(input, intSample, intDepth, top, left) * (1 - (delta_x - floor(delta_x))) * (1 - (delta_y - floor(delta_y))) + # VALUE_4(input, intSample, intDepth, top, right) * (delta_x - floor(delta_x)) * (1 - (delta_y - floor(delta_y))) + # VALUE_4(input, intSample, intDepth, bottom, left) * (1 - (delta_x - floor(delta_x))) * (delta_y - floor(delta_y)) + # VALUE_4(input, intSample, intDepth, bottom, right) * (delta_x - floor(delta_x)) * (delta_y - floor(delta_y)); ================================================ FILE: environment.yaml ================================================ name: ldmvfi channels: - pytorch - defaults - conda-forge dependencies: - python=3.9.13 - pytorch=1.11.0 - torchvision=0.12.0 - cudatoolkit=11.3 - pip: - opencv-python==4.6.0.66 - pudb==2022.1.3 - imageio==2.22.3 - imageio-ffmpeg==0.4.7 - pytorch-lightning==1.7.7 - omegaconf==2.2.3 - test-tube==0.7.5 - streamlit==1.14.0 - einops==0.5.0 - torch-fidelity==0.3.0 - transformers==4.23.1 - timm==0.6.12 - cupy - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers - -e git+https://github.com/openai/CLIP.git@main#egg=clip - -e . 
# conda create -n ldmvfi python=3.9 # conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch # pip install opencv-python==4.6.0.66 pudb==2022.1.3 imageio==2.22.3 imageio-ffmpeg==0.4.7 pytorch-lightning==1.7.7 omegaconf==2.2.3 test-tube==0.7.5 streamlit==1.14.0 einops==0.5.0 torch-fidelity==0.3.0 transformers==4.23.1 # pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers # pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip # pip install -e . # pip install timm ================================================ FILE: evaluate.py ================================================ import argparse import os import torch from functools import partial from omegaconf import OmegaConf from main import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.data import testsets parser = argparse.ArgumentParser(description='Frame Interpolation Evaluation') parser.add_argument('--config', type=str, default=None) parser.add_argument('--ckpt', type=str, default=None) parser.add_argument('--dataset', type=str, default='Middlebury_others') parser.add_argument('--metrics', nargs='+', type=str, default=['PSNR', 'SSIM', 'LPIPS']) parser.add_argument('--data_dir', type=str, default='D:\\') parser.add_argument('--out_dir', type=str, default='eval_results') parser.add_argument('--resume', dest='resume', default=False, action='store_true') # sampler args parser.add_argument('--use_ddim', dest='use_ddim', default=False, action='store_true') parser.add_argument('--ddim_eta', type=float, default=1.0) parser.add_argument('--ddim_steps', type=int, default=200) def main(): args = parser.parse_args() # initialise model config = OmegaConf.load(args.config) model = instantiate_from_config(config.model) model.load_state_dict(torch.load(args.ckpt)['state_dict']) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) 
model = model.eval() print('Model loaded successfully') # set up sampler if args.use_ddim: ddim = DDIMSampler(model) sample_func = partial(ddim.sample, S=args.ddim_steps, eta=args.ddim_eta, verbose=False) else: sample_func = partial(model.sample_ddpm, return_intermediates=False, verbose=False) # setup output dirs if not os.path.exists(args.out_dir): os.makedirs(args.out_dir) # initialise test set print('Testing on dataset: ', args.dataset) test_dir = os.path.join(args.out_dir, args.dataset) if args.dataset.split('_')[0] in ['VFITex', 'Ucf101', 'Davis90']: db_folder = args.dataset.split('_')[0].lower() else: db_folder = args.dataset.lower() test_db = getattr(testsets, args.dataset)(os.path.join(args.data_dir, db_folder)) if not os.path.exists(test_dir): os.mkdir(test_dir) test_db.eval(model, sample_func, metrics=args.metrics, output_dir=test_dir, resume=args.resume) if __name__ == '__main__': main() ================================================ FILE: evaluate_vqm.py ================================================ import argparse import os from ldm.data import testsets_vqm parser = argparse.ArgumentParser(description='Frame Interpolation Evaluation') parser.add_argument('--exp', type=str, default=None) parser.add_argument('--dataset', type=str, default='Middlebury_others') parser.add_argument('--metrics', nargs='+', type=str, default=['FloLPIPS']) parser.add_argument('--data_dir', type=str, default='D:\\') parser.add_argument('--out_dir', type=str, default='eval_results') parser.add_argument('--resume', dest='resume', default=False, action='store_true') def main(): args = parser.parse_args() # initialise model model = args.exp print('Evaluating model:', model) # setup output dirs assert os.path.exists(args.out_dir), 'Frames not previously interpolated!' # initialise test set print('Testing on dataset: ', args.dataset) test_dir = os.path.join(args.out_dir, args.dataset) assert os.path.exists(test_dir), f'{args.dataset} not pre-computed!' 
if args.dataset.split('_')[0] in ['VFITex', 'Ucf101', 'Davis90']: db_folder = args.dataset.split('_')[0].lower() else: db_folder = args.dataset.lower() test_db = getattr(testsets_vqm, args.dataset)(os.path.join(args.data_dir, db_folder)) test_db.eval(metrics=args.metrics, output_dir=test_dir, resume=args.resume) if __name__ == '__main__': main() ================================================ FILE: interpolate_yuv.py ================================================ import argparse import torch import torchvision.transforms.functional as TF import os from PIL import Image from tqdm import tqdm import skvideo.io from functools import partial from utility import read_frame_yuv2rgb, tensor2rgb from omegaconf import OmegaConf from main import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler parser = argparse.ArgumentParser(description='Frame Interpolation Evaluation') parser.add_argument('--net', type=str, default='LDMVFI') parser.add_argument('--config', type=str, default='configs/ldm/ldmvfi-vqflow-f32-c256-concat_max.yaml') parser.add_argument('--ckpt', type=str, default='ckpt.pth') parser.add_argument('--input_yuv', type=str, default='D:\\') parser.add_argument('--size', type=str, default='1920x1080') parser.add_argument('--out_fps', type=int, default=60) parser.add_argument('--out_dir', type=str, default='.') # sampler args parser.add_argument('--use_ddim', dest='use_ddim', default=False, action='store_true') parser.add_argument('--ddim_eta', type=float, default=1.0) parser.add_argument('--ddim_steps', type=int, default=200) def main(): args = parser.parse_args() # initialise model config = OmegaConf.load(args.config) model = instantiate_from_config(config.model) model.load_state_dict(torch.load(args.ckpt)['state_dict']) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) model = model.eval() print('Model loaded successfully') # set up sampler if args.use_ddim: ddim = 
DDIMSampler(model)
        sample_func = partial(ddim.sample, S=args.ddim_steps, eta=args.ddim_eta, verbose=False)
    else:
        sample_func = partial(model.sample_ddpm, return_intermediates=False, verbose=False)

    # Setup output file
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)
    _, fname = os.path.split(args.input_yuv)
    # drop the extension; str.strip('.yuv') strips characters from both ends, not a suffix
    seq_name = os.path.splitext(fname)[0]
    bit_depth = 16 if '16bit' in fname else 10 if '10bit' in fname else 8
    pix_fmt = '444' if '444' in fname else '420'
    try:
        width, height = (int(v) for v in args.size.split('x'))
    except ValueError:
        print("Invalid size, should be 'WIDTHxHEIGHT', e.g. '1920x1080'")
        return

    outname = '{}_{}x{}_{}fps_{}.mp4'.format(seq_name, width, height, args.out_fps, args.net)
    writer = skvideo.io.FFmpegWriter(os.path.join(args.out_dir, outname),
        inputdict={
            '-r': str(args.out_fps)
        },
        outputdict={
            '-pix_fmt': 'yuv420p',
            '-s': '{}x{}'.format(width, height),
            '-r': str(args.out_fps),
            '-vcodec': 'libx264',  # use the H.264 codec
            '-crf': '0',           # constant rate factor 0, which is lossless
            '-preset': 'veryslow'  # slower presets compress better, in principle;
                                   # for other options see https://trac.ffmpeg.org/wiki/Encode/H.264
        }
    )

    # Start interpolation
    print('Using model {} to upsample file {}'.format(args.net, fname))
    stream = open(args.input_yuv, 'rb')  # raw YUV is binary data
    file_size = os.path.getsize(args.input_yuv)

    # YUV reading setup
    bytes_per_frame = width*height*1.5
    if pix_fmt == '444':
        bytes_per_frame *= 2
    if bit_depth != 8:
        bytes_per_frame *= 2

    num_frames = int(file_size // bytes_per_frame)
    rawFrame0 = Image.fromarray(read_frame_yuv2rgb(stream, width, height, 0, bit_depth, pix_fmt))
    frame0 = TF.normalize(TF.to_tensor(rawFrame0), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))[None,...].cuda()

    for t in tqdm(range(1, num_frames)):
        rawFrame1 = Image.fromarray(read_frame_yuv2rgb(stream, width, height, t, bit_depth, pix_fmt))
        frame1 = TF.normalize(TF.to_tensor(rawFrame1), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))[None,...].cuda()

        with torch.no_grad():
            with model.ema_scope():
                # form condition tensor and
define shape of latent rep xc = {'prev_frame': frame0, 'next_frame': frame1} c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc) shape = (model.channels, c.shape[2], c.shape[3]) # run sampling and get denoised latent out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape) if isinstance(out, tuple): # using ddim out = out[0] # reconstruct interpolated frame from latent out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list) out = torch.clamp(out, min=-1., max=1.) # interpolated frame in [-1,1] # write to output video writer.writeFrame(tensor2rgb(frame0)[0]) writer.writeFrame(tensor2rgb(out)[0]) # update frame0 frame0 = frame1 # write the last frame writer.writeFrame(tensor2rgb(frame1)[0]) stream.close() writer.close() # close the writer if __name__ == "__main__": main() ================================================ FILE: ldm/data/__init__.py ================================================ ================================================ FILE: ldm/data/bvi_vimeo.py ================================================ import numpy as np import random from os import listdir from os.path import join, isdir, split, getsize from torch.utils.data import Dataset import torchvision.transforms.functional as TF from PIL import Image import ldm.data.vfitransforms as vt from functools import partial class Vimeo90k_triplet(Dataset): def __init__(self, db_dir, train=True, crop_sz=(256,256), augment_s=True, augment_t=True): seq_dir = join(db_dir, 'sequences') self.crop_sz = crop_sz self.augment_s = augment_s self.augment_t = augment_t if train: seq_list_txt = join(db_dir, 'sep_trainlist.txt') else: seq_list_txt = join(db_dir, 'sep_testlist.txt') with open(seq_list_txt) as f: contents = f.readlines() seq_path = [line.strip() for line in contents if line != '\n'] self.seq_path_list = [join(seq_dir, *line.split('/')) for line in seq_path] def __getitem__(self, index): rawFrame3 = Image.open(join(self.seq_path_list[index], "im3.png")) 
rawFrame4 = Image.open(join(self.seq_path_list[index], "im4.png")) rawFrame5 = Image.open(join(self.seq_path_list[index], "im5.png")) if self.crop_sz is not None: rawFrame3, rawFrame4, rawFrame5 = vt.rand_crop(rawFrame3, rawFrame4, rawFrame5, sz=self.crop_sz) if self.augment_s: rawFrame3, rawFrame4, rawFrame5 = vt.rand_flip(rawFrame3, rawFrame4, rawFrame5, p=0.5) if self.augment_t: rawFrame3, rawFrame4, rawFrame5 = vt.rand_reverse(rawFrame3, rawFrame4, rawFrame5, p=0.5) to_array = partial(np.array, dtype=np.float32) frame3, frame4, frame5 = map(to_array, (rawFrame3, rawFrame4, rawFrame5)) #(256,256,3), 0-255 frame3 = frame3/127.5 - 1.0 frame4 = frame4/127.5 - 1.0 frame5 = frame5/127.5 - 1.0 return {'image': frame4, 'prev_frame': frame3, 'next_frame': frame5} def __len__(self): return len(self.seq_path_list) class Vimeo90k_quintuplet(Dataset): def __init__(self, db_dir, train=True, crop_sz=(256,256), augment_s=True, augment_t=True): seq_dir = join(db_dir, 'sequences') self.crop_sz = crop_sz self.augment_s = augment_s self.augment_t = augment_t if train: seq_list_txt = join(db_dir, 'sep_trainlist.txt') else: seq_list_txt = join(db_dir, 'sep_testlist.txt') with open(seq_list_txt) as f: contents = f.readlines() seq_path = [line.strip() for line in contents if line != '\n'] self.seq_path_list = [join(seq_dir, *line.split('/')) for line in seq_path] def __getitem__(self, index): rawFrame1 = Image.open(join(self.seq_path_list[index], "im1.png")) rawFrame3 = Image.open(join(self.seq_path_list[index], "im3.png")) rawFrame4 = Image.open(join(self.seq_path_list[index], "im4.png")) rawFrame5 = Image.open(join(self.seq_path_list[index], "im5.png")) rawFrame7 = Image.open(join(self.seq_path_list[index], "im7.png")) if self.crop_sz is not None: rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_crop(rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7, sz=self.crop_sz) if self.augment_s: rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_flip(rawFrame1, 
rawFrame3, rawFrame4, rawFrame5, rawFrame7, p=0.5) if self.augment_t: rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_reverse(rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7, p=0.5) frame1, frame3, frame4, frame5, frame7 = map(TF.to_tensor, (rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7)) return frame1, frame3, frame4, frame5, frame7 def __len__(self): return len(self.seq_path_list) class BVIDVC_triplet(Dataset): def __init__(self, db_dir, res=None, crop_sz=(256,256), augment_s=True, augment_t=True): db_dir = join(db_dir, 'quintuplets') self.crop_sz = crop_sz self.augment_s = augment_s self.augment_t = augment_t self.seq_path_list = [join(db_dir, f) for f in listdir(db_dir)] def __getitem__(self, index): cat = Image.open(join(self.seq_path_list[index], 'quintuplet.png')) rawFrame3 = cat.crop((256, 0, 256*2, 256)) rawFrame5 = cat.crop((256*2, 0, 256*3, 256)) rawFrame4 = cat.crop((256*4, 0, 256*5, 256)) if self.crop_sz is not None: rawFrame3, rawFrame4, rawFrame5 = vt.rand_crop(rawFrame3, rawFrame4, rawFrame5, sz=self.crop_sz) if self.augment_s: rawFrame3, rawFrame4, rawFrame5 = vt.rand_flip(rawFrame3, rawFrame4, rawFrame5, p=0.5) if self.augment_t: rawFrame3, rawFrame4, rawFrame5 = vt.rand_reverse(rawFrame3, rawFrame4, rawFrame5, p=0.5) to_array = partial(np.array, dtype=np.float32) frame3, frame4, frame5 = map(to_array, (rawFrame3, rawFrame4, rawFrame5)) #(256,256,3), 0-255 frame3 = frame3/127.5 - 1.0 frame4 = frame4/127.5 - 1.0 frame5 = frame5/127.5 - 1.0 return {'image': frame4, 'prev_frame': frame3, 'next_frame': frame5} def __len__(self): return len(self.seq_path_list) class BVIDVC_quintuplet(Dataset): def __init__(self, db_dir, res=None, crop_sz=(256,256), augment_s=True, augment_t=True): db_dir = join(db_dir, 'quintuplets') self.crop_sz = crop_sz self.augment_s = augment_s self.augment_t = augment_t self.seq_path_list = [join(db_dir, f) for f in listdir(db_dir)] def __getitem__(self, index): cat = 
Image.open(join(self.seq_path_list[index], 'quintuplet.png')) rawFrame1 = cat.crop((0, 0, 256, 256)) rawFrame3 = cat.crop((256, 0, 256*2, 256)) rawFrame5 = cat.crop((256*2, 0, 256*3, 256)) rawFrame7 = cat.crop((256*3, 0, 256*4, 256)) rawFrame4 = cat.crop((256*4, 0, 256*5, 256)) if self.augment_s: rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_flip(rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7, p=0.5) if self.augment_t: rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7 = vt.rand_reverse(rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7, p=0.5) frame1, frame3, frame4, frame5, frame7 = map(TF.to_tensor, (rawFrame1, rawFrame3, rawFrame4, rawFrame5, rawFrame7)) return frame1, frame3, frame4, frame5, frame7 def __len__(self): return len(self.seq_path_list) class Sampler(Dataset): def __init__(self, datasets, p_datasets=None, iter=False, samples_per_epoch=1000): self.datasets = datasets self.len_datasets = np.array([len(dataset) for dataset in self.datasets]) self.p_datasets = p_datasets self.iter = iter if p_datasets is None: self.p_datasets = self.len_datasets / np.sum(self.len_datasets) self.samples_per_epoch = samples_per_epoch self.accum = [0,] for i, length in enumerate(self.len_datasets): self.accum.append(self.accum[-1] + self.len_datasets[i]) def __getitem__(self, index): if self.iter: # iterate through all datasets for i in range(len(self.accum)): if index < self.accum[i]: return self.datasets[i-1].__getitem__(index-self.accum[i-1]) else: # first sample a dataset dataset = random.choices(self.datasets, self.p_datasets)[0] # sample a sequence from the dataset return dataset.__getitem__(random.randint(0,len(dataset)-1)) def __len__(self): if self.iter: return int(np.sum(self.len_datasets)) else: return self.samples_per_epoch class BVI_Vimeo_triplet(Dataset): def __init__(self, db_dir, crop_sz=[256,256], p_datasets=None, iter=False, samples_per_epoch=1000): vimeo90k_train = Vimeo90k_triplet(join(db_dir, 'vimeo_septuplet'), 
train=True, crop_sz=crop_sz)
        bvidvc_train = BVIDVC_triplet(join(db_dir, 'bvidvc'), crop_sz=crop_sz)

        self.datasets = [vimeo90k_train, bvidvc_train]
        self.len_datasets = np.array([len(dataset) for dataset in self.datasets])
        self.p_datasets = p_datasets
        self.iter = iter

        if p_datasets is None:
            self.p_datasets = self.len_datasets / np.sum(self.len_datasets)

        self.samples_per_epoch = samples_per_epoch

        self.accum = [0,]
        for i, length in enumerate(self.len_datasets):
            self.accum.append(self.accum[-1] + self.len_datasets[i])

    def __getitem__(self, index):
        if self.iter:
            # iterate through all datasets
            for i in range(len(self.accum)):
                if index < self.accum[i]:
                    return self.datasets[i-1].__getitem__(index-self.accum[i-1])
        else:
            # first sample a dataset
            dataset = random.choices(self.datasets, self.p_datasets)[0]
            # sample a sequence from the dataset
            return dataset.__getitem__(random.randint(0,len(dataset)-1))

    def __len__(self):
        if self.iter:
            return int(np.sum(self.len_datasets))
        else:
            return self.samples_per_epoch

================================================
FILE: ldm/data/testsets.py
================================================
import glob
from typing import List
from PIL import Image
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.utils import save_image as imwrite
import os
from os.path import join, exists
import utility
import numpy as np
import ast
import time
from ldm.models.autoencoder import *


class TripletTestSet:
    def __init__(self):
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # output tensor in [-1,1]

    def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_dir=None, output_name='output.png', resume=False):
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results.txt')), 'no res file found to
resume from!'
            with open(join(output_dir, 'results.txt'), 'r') as f:
                prev_lines = f.readlines()
            for line in prev_lines:
                if len(line) < 2:
                    continue
                cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0])  # parse dict from string
                for k in metrics:
                    results_dict[k].append(float(cur_res[k]))
                start_idx += 1

        logfile = open(join(output_dir, 'results.txt'), 'a')

        for idx in range(len(self.im_list)):
            if resume and idx < start_idx:
                assert os.path.exists(join(output_dir, self.im_list[idx], output_name)), f'skipping idx {idx} but output not found!'
                continue

            print(f'Evaluating {self.im_list[idx]}')
            t0 = time.time()
            if not exists(join(output_dir, self.im_list[idx])):
                os.makedirs(join(output_dir, self.im_list[idx]))

            with torch.no_grad():
                with model.ema_scope():
                    # form condition tensor and define shape of latent rep
                    xc = {'prev_frame': self.input0_list[idx], 'next_frame': self.input1_list[idx]}
                    c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc)
                    shape = (model.channels, c.shape[2], c.shape[3])
                    # run sampling and get denoised latent rep
                    out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape, x_T=None)
                    if isinstance(out, tuple):  # using ddim
                        out = out[0]
                    # reconstruct interpolated frame from latent
                    out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list)
                    out = torch.clamp(out, min=-1., max=1.)
                    # interpolated frame in [-1,1]

            gt = self.gt_list[idx]

            for metric in metrics:
                score = getattr(utility, 'calc_{}'.format(metric.lower()))(gt, out, [self.input0_list[idx], self.input1_list[idx]])[0].item()
                results_dict[metric].append(score)

            imwrite(out, join(output_dir, self.im_list[idx], output_name), value_range=(-1, 1), normalize=True)

            msg = '{:<15s} -- {}'.format(self.im_list[idx], {k: round(results_dict[k][-1],3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n'
            print(msg, end='')
            logfile.write(msg)

        msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]),3) for k in metrics}) + '\n\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()


class Middlebury_others(TripletTestSet):
    def __init__(self, db_dir):
        super(Middlebury_others, self).__init__()
        self.im_list = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking']

        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input0_list.append(self.transform(Image.open(join(db_dir, 'input', item, 'frame10.png'))).cuda().unsqueeze(0))  # [1,3,H,W] in [-1,1]
            self.input1_list.append(self.transform(Image.open(join(db_dir, 'input', item, 'frame11.png'))).cuda().unsqueeze(0))  # [1,3,H,W] in [-1,1]
            self.gt_list.append(self.transform(Image.open(join(db_dir, 'gt', item, 'frame10i11.png'))).cuda().unsqueeze(0))


class Davis(TripletTestSet):
    def __init__(self, db_dir):
        super(Davis, self).__init__()
        self.im_list = ['bike-trial', 'boxing', 'burnout', 'choreography', 'demolition', 'dive-in', 'dolphins', 'e-bike', 'grass-chopper', 'hurdles', 'inflatable', 'juggle', 'kart-turn', 'kids-turning', 'lions', 'mbike-santa', 'monkeys', 'ocean-birds', 'pole-vault', 'running', 'selfie', 'skydive', 'speed-skating', 'swing-boy', 'tackle', 'turtle', 'varanus-tree', 'vietnam', 'wings-turn']

        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input0_list.append(self.transform(Image.open(join(db_dir, 'input', item, 'frame10.png'))).cuda().unsqueeze(0))
            self.input1_list.append(self.transform(Image.open(join(db_dir, 'input', item, 'frame11.png'))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, 'gt', item, 'frame10i11.png'))).cuda().unsqueeze(0))


class Ucf(TripletTestSet):
    def __init__(self, db_dir):
        super(Ucf, self).__init__()
        self.im_list = os.listdir(db_dir)

        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input0_list.append(self.transform(Image.open(join(db_dir, item, 'frame_00.png'))).cuda().unsqueeze(0))
            self.input1_list.append(self.transform(Image.open(join(db_dir, item, 'frame_02.png'))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, item, 'frame_01_gt.png'))).cuda().unsqueeze(0))


class Snufilm(TripletTestSet):
    def __init__(self, db_dir, mode):
        super(Snufilm, self).__init__()
        self.mode = mode
        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        with open(join(db_dir, 'test-{}.txt'.format(mode)), 'r') as f:
            triplet_list = f.read().splitlines()
        self.im_list = []
        for i, triplet in enumerate(triplet_list, 1):
            self.im_list.append('{}-{}'.format(mode, str(i).zfill(3)))
            lst = triplet.split(' ')
            self.input0_list.append(self.transform(Image.open(join(db_dir, lst[0]))).cuda().unsqueeze(0))
            self.input1_list.append(self.transform(Image.open(join(db_dir, lst[2]))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, lst[1]))).cuda().unsqueeze(0))


class Snufilm_easy(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-5]
        super(Snufilm_easy, self).__init__(db_dir, 'easy')


class Snufilm_medium(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-7]
        super(Snufilm_medium, self).__init__(db_dir, 'medium')


class Snufilm_hard(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-5]
        super(Snufilm_hard, self).__init__(db_dir, 'hard')


class
Snufilm_extreme(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-8]
        super(Snufilm_extreme, self).__init__(db_dir, 'extreme')


class VFITex_triplet:
    def __init__(self, db_dir):
        self.seq_list = os.listdir(db_dir)
        self.db_dir = db_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # output tensor in [-1,1]

    def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_dir=None, output_name=None, resume=False):
        model.eval()
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results.txt')), 'no res file found to resume from!'
            with open(join(output_dir, 'results.txt'), 'r') as f:
                prev_lines = f.readlines()
            for line in prev_lines:
                if len(line) < 2:
                    continue
                cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0])  # parse dict from string
                for k in metrics:
                    results_dict[k].append(float(cur_res[k]))
                start_idx += 1

        logfile = open(join(output_dir, 'results.txt'), 'a')

        for idx, seq in enumerate(self.seq_list):
            if resume and idx < start_idx:
                assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'skipping idx {idx} but output not found!'
                continue

            print(f'Evaluating {seq}')
            t0 = time.time()
            seqpath = join(self.db_dir, seq)
            if not exists(join(output_dir, seq)):
                os.makedirs(join(output_dir, seq))

            # interpolate between every 2 frames
            gt_list, out_list, inputs_list = [], [], []
            tmp_dict = {k : [] for k in metrics}
            num_frames = len([f for f in os.listdir(seqpath) if f.endswith('.png')])
            for t in range(1, num_frames-5, 2):
                im0 = Image.open(join(seqpath, str(t+2).zfill(3)+'.png'))
                im1 = Image.open(join(seqpath, str(t+3).zfill(3)+'.png'))
                im2 = Image.open(join(seqpath, str(t+4).zfill(3)+'.png'))
                # center crop if 4K
                if '4K' in seq:
                    w, h = im0.size
                    im0 = TF.center_crop(im0, (h//2, w//2))
                    im1 = TF.center_crop(im1, (h//2, w//2))
                    im2 = TF.center_crop(im2, (h//2, w//2))
                im0 = self.transform(im0).cuda().unsqueeze(0)
                im1 = self.transform(im1).cuda().unsqueeze(0)
                im2 = self.transform(im2).cuda().unsqueeze(0)

                with torch.no_grad():
                    with model.ema_scope():
                        # form condition tensor and define shape of latent rep
                        xc = {'prev_frame': im0, 'next_frame': im2}
                        c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc)
                        shape = (model.channels, c.shape[2], c.shape[3])
                        # run sampling and get denoised latent rep
                        out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape, x_T=None)
                        if isinstance(out, tuple):  # using ddim
                            out = out[0]
                        # reconstruct interpolated frame from latent
                        out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list)
                        out = torch.clamp(out, min=-1., max=1.)
                        # interpolated frame in [-1,1]

                for metric in metrics:
                    score = getattr(utility, 'calc_{}'.format(metric.lower()))(im1, out, [im0, im2])[0].item()
                    tmp_dict[metric].append(score)

                imwrite(out, join(output_dir, seq, 'frame{}.png'.format(t+3)), value_range=(-1, 1), normalize=True)

            # compute sequence-level scores
            for metric in metrics:
                results_dict[metric].append(np.mean(tmp_dict[metric]))

            # log
            msg = '{:<15s} -- {}'.format(seq, {k: round(results_dict[k][-1], 3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n'
            print(msg, end='')
            logfile.write(msg)

        msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]), 3) for k in metrics}) + '\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()


class Davis90_triplet:
    def __init__(self, db_dir):
        self.seq_list = sorted(os.listdir(db_dir))
        self.db_dir = db_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # output tensor in [-1,1]

    def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_dir=None, output_name=None, resume=False):
        model.eval()
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results.txt')), 'no res file found to resume from!'
            with open(join(output_dir, 'results.txt'), 'r') as f:
                prev_lines = f.readlines()
            for line in prev_lines:
                if len(line) < 2:
                    continue
                cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0])  # parse dict from string
                for k in metrics:
                    results_dict[k].append(float(cur_res[k]))
                start_idx += 1

        logfile = open(join(output_dir, 'results.txt'), 'a')

        for idx, seq in enumerate(self.seq_list):
            if resume and idx < start_idx:
                assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'skipping idx {idx} but output not found!'
continue print(f'Evaluating {seq}') t0 = time.time() seqpath = join(self.db_dir, seq) if not exists(join(output_dir, seq)): os.makedirs(join(output_dir, seq)) # interpolate between every 2 frames gt_list, out_list, inputs_list = [], [], [] tmp_dict = {k : [] for k in metrics} num_frames = len(os.listdir(seqpath)) for t in range(0, num_frames-6, 2): im3 = Image.open(join(seqpath, str(t+2).zfill(5)+'.jpg')) im4 = Image.open(join(seqpath, str(t+3).zfill(5)+'.jpg')) im5 = Image.open(join(seqpath, str(t+4).zfill(5)+'.jpg')) im3 = self.transform(im3).cuda().unsqueeze(0) im4 = self.transform(im4).cuda().unsqueeze(0) im5 = self.transform(im5).cuda().unsqueeze(0) with torch.no_grad(): with model.ema_scope(): # form condition tensor and define shape of latent rep xc = {'prev_frame': im3, 'next_frame': im5} c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc) shape = (model.channels, c.shape[2], c.shape[3]) # run sampling and get denoised latent rep out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape, x_T=None) if isinstance(out, tuple): # using ddim out = out[0] # reconstruct interpolated frame from latent out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list) out = torch.clamp(out, min=-1., max=1.) 
                # interpolated frame in [-1,1]
                for metric in metrics:
                    score = getattr(utility, 'calc_{}'.format(metric.lower()))(im4, out, [im3, im5])[0].item()
                    tmp_dict[metric].append(score)

                imwrite(out, join(output_dir, seq, 'frame{}.png'.format(t+3)), value_range=(-1, 1), normalize=True)

            # compute sequence-level scores
            for metric in metrics:
                results_dict[metric].append(np.mean(tmp_dict[metric]))

            # log
            msg = '{:<15s} -- {}'.format(seq, {k: round(results_dict[k][-1], 3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n'
            print(msg, end='')
            logfile.write(msg)

        msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]), 3) for k in metrics}) + '\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()


class Ucf101_triplet:
    def __init__(self, db_dir):
        self.db_dir = db_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # output tensor in [-1,1]

        self.im_list = os.listdir(db_dir)

        self.input3_list = []
        self.input5_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input3_list.append(self.transform(Image.open(join(db_dir, item, 'frame1.png'))).cuda().unsqueeze(0))
            self.input5_list.append(self.transform(Image.open(join(db_dir, item, 'frame2.png'))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, item, 'framet.png'))).cuda().unsqueeze(0))

    def eval(self, model, sample_func, metrics=['PSNR', 'SSIM'], output_dir=None, output_name='output.png', resume=False):
        model.eval()
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results.txt')), 'no res file found to resume from!'
with open(join(output_dir, 'results.txt'), 'r') as f: prev_lines = f.readlines() for line in prev_lines: if len(line) < 2: continue cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string for k in metrics: results_dict[k].append(float(cur_res[k])) start_idx += 1 logfile = open(join(output_dir, 'results.txt'), 'a') for idx in range(len(self.im_list)): if resume and idx < start_idx: assert os.path.exists(join(output_dir, self.im_list[idx], output_name)), f'skipping idx {idx} but output not found!' continue print(f'Evaluating {self.im_list[idx]}') t0 = time.time() if not exists(join(output_dir, self.im_list[idx])): os.makedirs(join(output_dir, self.im_list[idx])) with torch.no_grad(): with model.ema_scope(): # form condition tensor and define shape of latent rep xc = {'prev_frame': self.input3_list[idx], 'next_frame': self.input5_list[idx]} c, phi_prev_list, phi_next_list = model.get_learned_conditioning(xc) shape = (model.channels, c.shape[2], c.shape[3]) # run sampling and get denoised latent rep out = sample_func(conditioning=c, batch_size=c.shape[0], shape=shape, x_T=None) if isinstance(out, tuple): # using ddim out = out[0] # reconstruct interpolated frame from latent out = model.decode_first_stage(out, xc, phi_prev_list, phi_next_list) out = torch.clamp(out, min=-1., max=1.) 
            # interpolated frame in [-1,1]
            gt = self.gt_list[idx]

            for metric in metrics:
                score = getattr(utility, 'calc_{}'.format(metric.lower()))(gt, out, [self.input3_list[idx], self.input5_list[idx]])[0].item()
                results_dict[metric].append(score)

            imwrite(out, join(output_dir, self.im_list[idx], output_name), value_range=(-1, 1), normalize=True)

            msg = '{:<15s} -- {}'.format(self.im_list[idx], {k: round(results_dict[k][-1],3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n'
            print(msg, end='')
            logfile.write(msg)

        msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]),3) for k in metrics}) + '\n\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()

================================================
FILE: ldm/data/testsets_vqm.py
================================================
import glob
from typing import List
from PIL import Image
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF
from torchvision.utils import save_image as imwrite
import os
from os.path import join, exists
import utility
import numpy as np
import ast
import time


class TripletTestSet:
    def __init__(self):
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # output tensor in [-1,1]

    def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name='output.png', resume=False):
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results_vqm.txt')), 'No res file found to resume from!'
with open(join(output_dir, 'results_vqm.txt'), 'r') as f: prev_lines = f.readlines() for line in prev_lines: if len(line) < 2: continue cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string for k in metrics: results_dict[k].append(float(cur_res[k])) start_idx += 1 logfile = open(join(output_dir, 'results_vqm.txt'), 'a') for idx in range(len(self.im_list)): if resume and idx < start_idx: assert os.path.exists(join(output_dir, self.im_list[idx], output_name)), f'skipping idx {idx} but output not found!' continue print(f'Evaluating {self.im_list[idx]}') t0 = time.time() assert exists(join(output_dir, self.im_list[idx], output_name)), f'No interpolated frames found for {self.im_list[idx]}' out = self.transform(Image.open(join(output_dir, self.im_list[idx], output_name))).cuda().unsqueeze(0) gt = self.gt_list[idx] for metric in metrics: score = getattr(utility, 'calc_{}'.format(metric.lower()))([gt], [out], [self.input0_list[idx], self.input1_list[idx]]) results_dict[metric].append(score) msg = '{:<15s} -- {}'.format(self.im_list[idx], {k: round(results_dict[k][-1],3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n' print(msg, end='') logfile.write(msg) msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]),3) for k in metrics}) + '\n\n' print(msg, end='') logfile.write(msg) logfile.close() class Middlebury_others(TripletTestSet): def __init__(self, db_dir): super(Middlebury_others, self).__init__() self.im_list = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking'] self.input0_list = [] self.input1_list = [] self.gt_list = [] for item in self.im_list: self.input0_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame10.png'))).cuda().unsqueeze(0)) # [1,3,H,W] in [-1,1] self.input1_list.append(self.transform(Image.open(join(db_dir, 'input', item , 
'frame11.png'))).cuda().unsqueeze(0)) # [1,3,H,W] in [-1,1] self.gt_list.append(self.transform(Image.open(join(db_dir, 'gt', item , 'frame10i11.png'))).cuda().unsqueeze(0)) class Davis(TripletTestSet): def __init__(self, db_dir): super(Davis, self).__init__() self.im_list = ['bike-trial', 'boxing', 'burnout', 'choreography', 'demolition', 'dive-in', 'dolphins', 'e-bike', 'grass-chopper', 'hurdles', 'inflatable', 'juggle', 'kart-turn', 'kids-turning', 'lions', 'mbike-santa', 'monkeys', 'ocean-birds', 'pole-vault', 'running', 'selfie', 'skydive', 'speed-skating', 'swing-boy', 'tackle', 'turtle', 'varanus-tree', 'vietnam', 'wings-turn'] self.input0_list = [] self.input1_list = [] self.gt_list = [] for item in self.im_list: self.input0_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame10.png'))).cuda().unsqueeze(0)) self.input1_list.append(self.transform(Image.open(join(db_dir, 'input', item , 'frame11.png'))).cuda().unsqueeze(0)) self.gt_list.append(self.transform(Image.open(join(db_dir, 'gt', item , 'frame10i11.png'))).cuda().unsqueeze(0)) class Ucf(TripletTestSet): def __init__(self, db_dir): super(Ucf, self).__init__() self.im_list = os.listdir(db_dir) self.input0_list = [] self.input1_list = [] self.gt_list = [] for item in self.im_list: self.input0_list.append(self.transform(Image.open(join(db_dir, item , 'frame_00.png'))).cuda().unsqueeze(0)) self.input1_list.append(self.transform(Image.open(join(db_dir, item , 'frame_02.png'))).cuda().unsqueeze(0)) self.gt_list.append(self.transform(Image.open(join(db_dir, item , 'frame_01_gt.png'))).cuda().unsqueeze(0)) class Snufilm(TripletTestSet): def __init__(self, db_dir, mode): super(Snufilm, self).__init__() self.mode = mode self.input0_list = [] self.input1_list = [] self.gt_list = [] with open(join(db_dir, 'test-{}.txt'.format(mode)), 'r') as f: triplet_list = f.read().splitlines() self.im_list = [] for i, triplet in enumerate(triplet_list, 1): self.im_list.append('{}-{}'.format(mode, 
str(i).zfill(3)))
            lst = triplet.split(' ')
            self.input0_list.append(self.transform(Image.open(join(db_dir, lst[0]))).cuda().unsqueeze(0))
            self.input1_list.append(self.transform(Image.open(join(db_dir, lst[2]))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, lst[1]))).cuda().unsqueeze(0))


class Snufilm_easy(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-5]
        super(Snufilm_easy, self).__init__(db_dir, 'easy')


class Snufilm_medium(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-7]
        super(Snufilm_medium, self).__init__(db_dir, 'medium')


class Snufilm_hard(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-5]
        super(Snufilm_hard, self).__init__(db_dir, 'hard')


class Snufilm_extreme(Snufilm):
    def __init__(self, db_dir):
        db_dir = db_dir[:-8]
        super(Snufilm_extreme, self).__init__(db_dir, 'extreme')


class VFITex_triplet:
    def __init__(self, db_dir):
        self.seq_list = os.listdir(db_dir)
        self.db_dir = db_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # output tensor in [-1,1]

    def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name=None, resume=False):
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results_vqm.txt')), 'no res file found to resume from!'
            with open(join(output_dir, 'results_vqm.txt'), 'r') as f:
                prev_lines = f.readlines()
                for line in prev_lines:
                    if len(line) < 2:
                        continue
                    cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0])  # parse dict from string
                    for k in metrics:
                        results_dict[k].append(float(cur_res[k]))
                    start_idx += 1

        logfile = open(join(output_dir, 'results_vqm.txt'), 'a')

        for idx, seq in enumerate(self.seq_list):
            if resume and idx < start_idx:
                assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'skipping idx {idx} but output not found!'
                continue

            print(f'Evaluating {seq}')
            t0 = time.time()

            seqpath = join(self.db_dir, seq)
            assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'No interpolated frames found for {seq}'

            # interpolate between every 2 frames
            gt_list, out_list, inputs_list = [], [], []
            num_frames = len([f for f in os.listdir(seqpath) if f.endswith('.png')])
            for t in range(1, num_frames-5, 2):
                im0 = Image.open(join(seqpath, str(t+2).zfill(3)+'.png'))
                im1 = Image.open(join(seqpath, str(t+3).zfill(3)+'.png'))
                im2 = Image.open(join(seqpath, str(t+4).zfill(3)+'.png'))
                # center crop if 4K
                if '4K' in seq:
                    w, h = im0.size
                    im0 = TF.center_crop(im0, (h//2, w//2))
                    im1 = TF.center_crop(im1, (h//2, w//2))
                    im2 = TF.center_crop(im2, (h//2, w//2))
                im0 = self.transform(im0).cuda().unsqueeze(0)
                im1 = self.transform(im1).cuda().unsqueeze(0)
                im2 = self.transform(im2).cuda().unsqueeze(0)

                out = self.transform(Image.open(join(output_dir, seq, 'frame{}.png'.format(t+3)))).cuda().unsqueeze(0)

                gt_list.append(im1)
                out_list.append(out)
                if t == 1:
                    inputs_list.append(im0)
                inputs_list.append(im2)

            # compute sequence-level scores
            for metric in metrics:
                score = getattr(utility, 'calc_{}'.format(metric.lower()))(gt_list, out_list, inputs_list)
                results_dict[metric].append(score)

            # log
            msg = '{:<15s} -- {}'.format(seq, {k: round(results_dict[k][-1], 3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n'
            print(msg, end='')
            logfile.write(msg)

        msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]), 3) for k in metrics}) + '\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()


class Davis90_triplet:
    def __init__(self, db_dir):
        self.seq_list = sorted(os.listdir(db_dir))
        self.db_dir = db_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # output tensor in [-1,1]

    def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name=None, resume=False):
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume: #
fill in results_dict with prev results and find where to start from assert os.path.exists(join(output_dir, 'results_vqm.txt')), 'no res file found to resume from!' with open(join(output_dir, 'results_vqm.txt'), 'r') as f: prev_lines = f.readlines() for line in prev_lines: if len(line) < 2: continue cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0]) #parse dict from string for k in metrics: results_dict[k].append(float(cur_res[k])) start_idx += 1 logfile = open(join(output_dir, 'results_vqm.txt'), 'a') for idx, seq in enumerate(self.seq_list): if resume and idx < start_idx: assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'skipping idx {idx} but output not found!' continue print(f'Evaluating {seq}') t0 = time.time() seqpath = join(self.db_dir, seq) assert len(glob.glob(join(output_dir, seq, '*.png'))) > 0, f'No interpolated frames found for {seq}' # interpolate between every 2 frames gt_list, out_list, inputs_list = [], [], [] num_frames = len(os.listdir(seqpath)) for t in range(0, num_frames-6, 2): im3 = Image.open(join(seqpath, str(t+2).zfill(5)+'.jpg')) im4 = Image.open(join(seqpath, str(t+3).zfill(5)+'.jpg')) im5 = Image.open(join(seqpath, str(t+4).zfill(5)+'.jpg')) im3 = self.transform(im3).cuda().unsqueeze(0) im4 = self.transform(im4).cuda().unsqueeze(0) im5 = self.transform(im5).cuda().unsqueeze(0) out = self.transform(Image.open(join(output_dir, seq, 'frame{}.png'.format(t+3)))).cuda().unsqueeze(0) gt_list.append(im4) out_list.append(out) if t == 0: inputs_list.append(im3) inputs_list.append(im5) # compute sequence-level scores for metric in metrics: score = getattr(utility, 'calc_{}'.format(metric.lower()))(gt_list, out_list, inputs_list) results_dict[metric].append(score) # log msg = '{:<15s} -- {}'.format(seq, {k: round(results_dict[k][-1], 3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n' print(msg, end='') logfile.write(msg) msg = '{:<15s} -- {}'.format('Average', {k: 
round(np.mean(results_dict[k]), 3) for k in metrics}) + '\n'
        print(msg, end='')
        logfile.write(msg)
        logfile.close()


class Ucf101_triplet:
    def __init__(self, db_dir):
        self.db_dir = db_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # output tensor in [-1,1]

        self.im_list = os.listdir(db_dir)

        self.input3_list = []
        self.input5_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input3_list.append(self.transform(Image.open(join(db_dir, item, 'frame1.png'))).cuda().unsqueeze(0))
            self.input5_list.append(self.transform(Image.open(join(db_dir, item, 'frame2.png'))).cuda().unsqueeze(0))
            self.gt_list.append(self.transform(Image.open(join(db_dir, item, 'framet.png'))).cuda().unsqueeze(0))

    def eval(self, metrics=['FloLPIPS'], output_dir=None, output_name='output.png', resume=False):
        results_dict = {k : [] for k in metrics}

        start_idx = 0
        if resume:
            # fill in results_dict with prev results and find where to start from
            assert os.path.exists(join(output_dir, 'results_vqm.txt')), 'no res file found to resume from!'
            with open(join(output_dir, 'results_vqm.txt'), 'r') as f:
                prev_lines = f.readlines()
                for line in prev_lines:
                    if len(line) < 2:
                        continue
                    cur_res = ast.literal_eval(line.strip().split('-- ')[1].split('time')[0])  # parse dict from string
                    for k in metrics:
                        results_dict[k].append(float(cur_res[k]))
                    start_idx += 1

        logfile = open(join(output_dir, 'results_vqm.txt'), 'a')

        for idx in range(len(self.im_list)):
            if resume and idx < start_idx:
                assert os.path.exists(join(output_dir, self.im_list[idx], output_name)), f'skipping idx {idx} but output not found!'
continue print(f'Evaluating {self.im_list[idx]}') t0 = time.time() assert exists(join(output_dir, self.im_list[idx], output_name)), f'No interpolated frames found for {self.im_list[idx]}' out = self.transform(Image.open(join(output_dir, self.im_list[idx], output_name))).cuda().unsqueeze(0) gt = self.gt_list[idx] for metric in metrics: score = getattr(utility, 'calc_{}'.format(metric.lower()))([gt], [out], [self.input3_list[idx], self.input5_list[idx]]) results_dict[metric].append(score) msg = '{:<15s} -- {}'.format(self.im_list[idx], {k: round(results_dict[k][-1],3) for k in metrics}) + f' time taken: {round(time.time()-t0,2)}' + '\n' print(msg, end='') logfile.write(msg) msg = '{:<15s} -- {}'.format('Average', {k: round(np.mean(results_dict[k]),3) for k in metrics}) + '\n\n' print(msg, end='') logfile.write(msg) logfile.close() ================================================ FILE: ldm/data/vfitransforms.py ================================================ import random import torch import torchvision import torchvision.transforms.functional as TF def rand_crop(*args, sz): i, j, h, w = torchvision.transforms.RandomCrop.get_params(args[0], output_size=sz) out = [] for im in args: out.append(TF.crop(im, i, j, h, w)) return out def rand_flip(*args, p): out = list(args) if random.random() < p: for i, im in enumerate(out): out[i] = TF.hflip(im) if random.random() < p: for i, im in enumerate(out): out[i] = TF.vflip(im) return out def rand_reverse(*args, p): if random.random() < p: return args[::-1] else: return args ================================================ FILE: ldm/lr_scheduler.py ================================================ import numpy as np class LambdaWarmUpCosineScheduler: """ note: use with a base_lr of 1.0 """ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): self.lr_warm_up_steps = warm_up_steps self.lr_start = lr_start self.lr_min = lr_min self.lr_max = lr_max self.lr_max_decay_steps = 
max_decay_steps self.last_lr = 0. self.verbosity_interval = verbosity_interval def schedule(self, n, **kwargs): if self.verbosity_interval > 0: if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") if n < self.lr_warm_up_steps: lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start self.last_lr = lr return lr else: t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) t = min(t, 1.0) lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 1 + np.cos(t * np.pi)) self.last_lr = lr return lr def __call__(self, n, **kwargs): return self.schedule(n,**kwargs) class LambdaWarmUpCosineScheduler2: """ supports repeated iterations, configurable via lists note: use with a base_lr of 1.0. """ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) self.lr_warm_up_steps = warm_up_steps self.f_start = f_start self.f_min = f_min self.f_max = f_max self.cycle_lengths = cycle_lengths self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) self.last_f = 0. 
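        # Worked example (illustrative values, not taken from the original source):
        # cycle_lengths = [100, 200] gives cum_cycles = [0, 100, 300]. For a global step
        # n = 150, find_in_interval(150) returns cycle 1, and schedule() then works with
        # the cycle-local step 150 - cum_cycles[1] = 50.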
self.verbosity_interval = verbosity_interval def find_in_interval(self, n): interval = 0 for cl in self.cum_cycles[1:]: if n <= cl: return interval interval += 1 def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f return f else: t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) t = min(t, 1.0) f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 1 + np.cos(t * np.pi)) self.last_f = f return f def __call__(self, n, **kwargs): return self.schedule(n, **kwargs) class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] if self.verbosity_interval > 0: if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] self.last_f = f return f else: f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) self.last_f = f return f ================================================ FILE: ldm/models/autoencoder.py ================================================ import torch import pytorch_lightning as pl import torch.nn.functional as F from torch.optim.lr_scheduler import LambdaLR import numpy as np from packaging import version from ldm.modules.ema import LitEma from contextlib import contextmanager from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer from 
ldm.modules.diffusionmodules.model import * from ldm.util import instantiate_from_config class VQFlowNet(pl.LightningModule): def __init__(self, ddconfig, lossconfig, n_embed, embed_dim, ckpt_path=None, ignore_keys=[], image_key="image", colorize_nlabels=None, monitor=None, batch_resize_range=None, scheduler_config=None, lr_g_factor=1.0, remap=None, sane_index_shape=False, # tell vector quantizer to return indices as bhw use_ema=False ): super().__init__() self.embed_dim = embed_dim # 3 self.n_embed = n_embed # 8192 self.image_key = image_key # 'image' self.encoder = FlowEncoder(**ddconfig) self.decoder = FlowDecoderWithResidual(**ddconfig) self.loss = instantiate_from_config(lossconfig) self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, remap=remap, sane_index_shape=sane_index_shape) self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) if colorize_nlabels is not None: assert type(colorize_nlabels)==int self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) if monitor is not None: self.monitor = monitor self.batch_resize_range = batch_resize_range if self.batch_resize_range is not None: print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self) print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) self.scheduler_config = scheduler_config self.lr_g_factor = lr_g_factor self.h0 = None self.w0 = None self.h_padded = None self.w_padded = None self.pad_h = 0 self.pad_w = 0 @contextmanager def ema_scope(self, context=None): if self.use_ema: self.model_ema.store(self.parameters()) self.model_ema.copy_to(self) if context is not None: print(f"{context}: Switched to EMA weights") try: yield None finally: if self.use_ema: 
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
                    break  # key already removed; a second matching prefix would raise KeyError
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x, ret_feature=False):
        '''
        Set ret_feature = True when encoding conditions in ddpm
        '''
        # Pad the input first so its size is divisible by min_side (a multiple of 8, computed below).
        # this is to tolerate different f values, various size inputs,
        # and some operations in the DDPM unet model.
        self.h0, self.w0 = x.shape[2:]
        # 8: window size for max vit
        # 2**(nr-1): f
        # 4: factor of downsampling in DDPM unet
        min_side = 8 * 2**(self.encoder.num_resolutions-1) * 4
        if self.h0 % min_side != 0:
            pad_h = min_side - (self.h0 % min_side)
            if pad_h == self.h0:  # this is to avoid padding 256 patches
                pad_h = 0
            x = F.pad(x, (0, 0, 0, pad_h), mode='reflect')
            self.h_padded = True
            self.pad_h = pad_h

        if self.w0 % min_side != 0:
            pad_w = min_side - (self.w0 % min_side)
            if pad_w == self.w0:
                pad_w = 0
            x = F.pad(x, (0, pad_w, 0, 0), mode='reflect')
            self.w_padded = True
            self.pad_w = pad_w

        phi_list = None
        if ret_feature:
            h, phi_list = self.encoder(x, ret_feature)
        else:
            h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        if ret_feature:
            return quant, emb_loss, info, phi_list
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant, x_prev, x_next):
        cond_dict = dict(
            phi_prev_list = self.encode(x_prev,
ret_feature=True)[-1], phi_next_list = self.encode(x_next, ret_feature=True)[-1], frame_prev = F.pad(x_prev, (0, self.pad_w, 0, self.pad_h), mode='reflect'), frame_next = F.pad(x_next, (0, self.pad_w, 0, self.pad_h), mode='reflect') ) quant = self.post_quant_conv(quant) dec = self.decoder(quant, cond_dict) # check if image is padded and return the original part only if self.h_padded: dec = dec[:, :, 0:self.h0, :] if self.w_padded: dec = dec[:, :, :, 0:self.w0] return dec def decode_code(self, code_b): quant_b = self.quantize.embed_code(code_b) dec = self.decode(quant_b) return dec def forward(self, input, x_prev, x_next, return_pred_indices=False): quant, diff, (_,_,ind) = self.encode(input) dec = self.decode(quant, x_prev, x_next) if return_pred_indices: return dec, diff, ind return dec, diff def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() if self.batch_resize_range is not None: lower_size = self.batch_resize_range[0] upper_size = self.batch_resize_range[1] if self.global_step <= 4: # do the first few batches with max size to avoid later oom new_resize = upper_size else: new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) if new_resize != x.shape[2]: x = F.interpolate(x, size=new_resize, mode="bicubic") x = x.detach() return x def training_step(self, batch, batch_idx, optimizer_idx): # https://github.com/pytorch/pytorch/issues/37142 # try not to fool the heuristics x = self.get_input(batch, self.image_key) x_prev = self.get_input(batch, 'prev_frame') x_next = self.get_input(batch, 'next_frame') xrec, qloss = self(x, x_prev, x_next) if optimizer_idx == 0: # autoencode aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train") self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) return aeloss if optimizer_idx == 1: # discriminator 
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train") self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) return discloss def validation_step(self, batch, batch_idx): log_dict = self._validation_step(batch, batch_idx) with self.ema_scope(): log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") return log_dict def _validation_step(self, batch, batch_idx, suffix=""): x = self.get_input(batch, self.image_key) x_prev = self.get_input(batch, 'prev_frame') x_next = self.get_input(batch, 'next_frame') xrec, qloss = self(x, x_prev, x_next) aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, last_layer=self.get_last_layer(), split="val"+suffix, ) discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, last_layer=self.get_last_layer(), split="val"+suffix, ) rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] self.log(f"val{suffix}/rec_loss", rec_loss, prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) self.log(f"val{suffix}/aeloss", aeloss, prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) if version.parse(pl.__version__) >= version.parse('1.4.0'): del log_dict_ae[f"val{suffix}/rec_loss"] self.log_dict(log_dict_ae) self.log_dict(log_dict_disc) return self.log_dict def configure_optimizers(self): lr_d = self.learning_rate lr_g = self.lr_g_factor*self.learning_rate print("lr_d", lr_d) print("lr_g", lr_g) opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ list(self.decoder.parameters())+ list(self.quantize.parameters())+ list(self.quant_conv.parameters())+ list(self.post_quant_conv.parameters()), lr=lr_g, betas=(0.5, 0.9)) opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9)) if self.scheduler_config is not None: scheduler = instantiate_from_config(self.scheduler_config) print("Setting up LambdaLR scheduler...") scheduler = [ { 
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x_prev = self.get_input(batch, 'prev_frame')
        x_next = self.get_input(batch, 'next_frame')
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log
        xrec, _ = self(x, x_prev, x_next)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x, x_prev, x_next)
                if x.shape[1] > 3:
                    xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x


class VQFlowNetInterface(VQFlowNet):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def encode(self, x, ret_feature=False):
        '''
        Set ret_feature = True when encoding conditions in ddpm
        '''
        # Pad the input first so its size is divisible by min_side (a multiple of 8, set below).
        # this is to tolerate different f values, various size inputs,
        # and some operations in the DDPM unet model.
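        # Worked example (illustrative input sizes, not taken from the original source),
        # using the min_side rule set just below (512, or 256 when the input height is at most 256):
        #   1080x1920 input: 1080 % 512 = 56,  so pad_h = 512 - 56  = 456 and the padded height is 1536
        #                    1920 % 512 = 384, so pad_w = 512 - 384 = 128 and the padded width is 2048
        #   256x256 input:   min_side = 256 and 256 % 256 == 0, so nothing is padded
        #   128x128 input:   min_side = 256 and the computed pad (128) equals the input size,
        #                    so it is reset to 0 and the input passes through unpadded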
self.h0, self.w0 = x.shape[2:] # 8: window size for max vit # 2**(nr-1): f # 4: factor of downsampling in DDPM unet min_side = 512#8 * 2**(self.encoder.num_resolutions-1) * 16 min_side = min_side // 2 if self.h0 <= 256 else min_side if self.h0 % min_side != 0: pad_h = min_side - (self.h0 % min_side) if pad_h == self.h0: # this is to avoid padding 256 patches pad_h = 0 x = F.pad(x, (0, 0, 0, pad_h), mode='reflect') self.h_padded = True self.pad_h = pad_h if self.w0 % min_side != 0: pad_w = min_side - (self.w0 % min_side) if pad_w == self.w0: pad_w = 0 x = F.pad(x, (0, pad_w, 0, 0), mode='reflect') self.w_padded = True self.pad_w = pad_w phi_list = None if ret_feature: h, phi_list = self.encoder(x, ret_feature) else: h = self.encoder(x) h = self.quant_conv(h) if ret_feature: return h, phi_list return h def decode(self, h, x_prev, x_next, phi_prev_list, phi_next_list, force_not_quantize=False): # also go through quantization layer if not force_not_quantize: quant, emb_loss, info = self.quantize(h) else: quant = h cond_dict = dict( phi_prev_list = phi_prev_list, phi_next_list = phi_next_list, frame_prev = F.pad(x_prev, (0, self.pad_w, 0, self.pad_h), mode='reflect'), frame_next = F.pad(x_next, (0, self.pad_w, 0, self.pad_h), mode='reflect') ) quant = self.post_quant_conv(quant) dec = self.decoder(quant, cond_dict) # check if image is padded and return the original part only if self.h_padded: dec = dec[:, :, 0:self.h0, :] if self.w_padded: dec = dec[:, :, :, 0:self.w0] return dec ================================================ FILE: ldm/models/diffusion/__init__.py ================================================ ================================================ FILE: ldm/models/diffusion/ddim.py ================================================ """SAMPLING ONLY.""" import torch import numpy as np from tqdm import tqdm from functools import partial from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like class 
DDIMSampler(object): def __init__(self, model, schedule="linear", **kwargs): super().__init__() self.model = model self.ddpm_num_timesteps = model.num_timesteps self.schedule = schedule def register_buffer(self, name, attr): if type(attr) == torch.Tensor: if attr.device != torch.device("cuda"): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) alphas_cumprod = self.model.alphas_cumprod assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) self.register_buffer('betas', to_torch(self.model.betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) # ddim sampling parameters ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), ddim_timesteps=self.ddim_timesteps, eta=ddim_eta,verbose=verbose) self.register_buffer('ddim_sigmas', ddim_sigmas) self.register_buffer('ddim_alphas', ddim_alphas) self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) @torch.no_grad() def sample(self, S, batch_size, shape, conditioning=None, callback=None, normals_sequence=None, img_callback=None, quantize_x0=False, eta=0., mask=None, x0=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, verbose=True, x_T=None, log_every_t=100, unconditional_guidance_scale=1., unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
**kwargs ): if conditioning is not None: if isinstance(conditioning, dict): cbs = conditioning[list(conditioning.keys())[0]].shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") else: if conditioning.shape[0] != batch_size: print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) # sampling C, H, W = shape size = (batch_size, C, H, W) if verbose: print(f'Data shape for DDIM sampling is {size}, eta {eta}') samples, intermediates = self.ddim_sampling(conditioning, size, callback=callback, img_callback=img_callback, quantize_denoised=quantize_x0, mask=mask, x0=x0, ddim_use_original_steps=False, noise_dropout=noise_dropout, temperature=temperature, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, x_T=x_T, log_every_t=log_every_t, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, verbose=verbose ) return samples, intermediates @torch.no_grad() def ddim_sampling(self, cond, shape, x_T=None, ddim_use_original_steps=False, callback=None, timesteps=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None,verbose=True): device = self.model.betas.device b = shape[0] if x_T is None: img = torch.randn(shape, device=device) else: img = x_T if timesteps is None: timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps elif timesteps is not None and not ddim_use_original_steps: subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 timesteps = self.ddim_timesteps[:subset_end] intermediates = {'x_inter': [img], 'pred_x0': [img]} time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else 
np.flip(timesteps) total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] if verbose: print(f"Running DDIM Sampling with {total_steps} timesteps") iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) if verbose else time_range for i, step in enumerate(iterator): index = total_steps - i - 1 ts = torch.full((b,), step, device=device, dtype=torch.long) if mask is not None: assert x0 is not None img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? img = img_orig * mask + (1. - mask) * img outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, quantize_denoised=quantize_denoised, temperature=temperature, noise_dropout=noise_dropout, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) img, pred_x0 = outs if callback: callback(i) if img_callback: img_callback(pred_x0, i) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(img) intermediates['pred_x0'].append(pred_x0) return img, intermediates @torch.no_grad() def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, unconditional_guidance_scale=1., unconditional_conditioning=None): b, *_, device = *x.shape, x.device if unconditional_conditioning is None or unconditional_guidance_scale == 1.: e_t = self.model.apply_model(x, t, c) else: x_in = torch.cat([x] * 2) t_in = torch.cat([t] * 2) c_in = torch.cat([unconditional_conditioning, c]) e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) if score_corrector is not None: assert self.model.parameterization == "eps" e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) alphas 
= self.model.alphas_cumprod if use_original_steps else self.ddim_alphas alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas # ddim_sigmas_for_original_num_steps is registered on the sampler in make_schedule, not on the model sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas # select parameters corresponding to the currently considered timestep a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) # current prediction for x_0 pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() if quantize_denoised: pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) # direction pointing to x_t dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature if noise_dropout > 0.: noise = torch.nn.functional.dropout(noise, p=noise_dropout) x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise return x_prev, pred_x0 ================================================ FILE: ldm/models/diffusion/ddpm.py ================================================ """ wild mixture of https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py https://github.com/CompVis/taming-transformers -- merci """ import torch import torch.nn as nn import numpy as np import pytorch_lightning as pl from torch.optim.lr_scheduler import LambdaLR from einops import rearrange, repeat from contextlib import contextmanager from functools import partial from tqdm import tqdm from
torchvision.utils import make_grid from pytorch_lightning.utilities.distributed import rank_zero_only from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config from ldm.modules.ema import LitEma from ldm.models.autoencoder import * from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like from ldm.models.diffusion.ddim import DDIMSampler __conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y'} def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode does not change anymore.""" return self class DDPM(pl.LightningModule): # classic DDPM with Gaussian diffusion, in image space def __init__(self, unet_config, timesteps=1000, beta_schedule="linear", loss_type="l2", ckpt_path=None, ignore_keys=[], load_only_unet=False, monitor="val/loss", use_ema=True, first_stage_key="image", image_size=256, channels=3, log_every_t=100, clip_denoised=True, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, given_betas=None, original_elbo_weight=0., v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta l_simple_weight=1., conditioning_key=None, parameterization="eps", # all assuming fixed variance schedules scheduler_config=None, use_positional_encodings=False, learn_logvar=False, logvar_init=0., ): super().__init__() assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' self.parameterization = parameterization print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") self.cond_stage_model = None self.clip_denoised = clip_denoised self.log_every_t = log_every_t self.first_stage_key = first_stage_key self.image_size = image_size # try conv? 
self.channels = channels self.use_positional_encodings = use_positional_encodings self.model = DiffusionWrapper(unet_config, conditioning_key) count_params(self.model, verbose=True) self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self.model) print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") self.use_scheduler = scheduler_config is not None if self.use_scheduler: self.scheduler_config = scheduler_config self.v_posterior = v_posterior self.original_elbo_weight = original_elbo_weight self.l_simple_weight = l_simple_weight if monitor is not None: self.monitor = monitor if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) self.loss_type = loss_type self.learn_logvar = learn_logvar self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) if self.learn_logvar: self.logvar = nn.Parameter(self.logvar, requires_grad=True) def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): if exists(given_betas): betas = given_betas else: betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) alphas = 1. 
- betas alphas_cumprod = np.cumprod(alphas, axis=0) alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) timesteps, = betas.shape self.num_timesteps = int(timesteps) self.linear_start = linear_start self.linear_end = linear_end assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' to_torch = partial(torch.tensor, dtype=torch.float32) self.register_buffer('betas', to_torch(betas)) self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) # calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( 1. - alphas_cumprod) + self.v_posterior * betas # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) self.register_buffer('posterior_variance', to_torch(posterior_variance)) # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) self.register_buffer('posterior_mean_coef1', to_torch( betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) self.register_buffer('posterior_mean_coef2', to_torch( (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. 
- alphas_cumprod))) if self.parameterization == "eps": lvlb_weights = self.betas ** 2 / ( 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) elif self.parameterization == "x0": lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) else: raise NotImplementedError("mu not supported") # TODO how to choose this term lvlb_weights[0] = lvlb_weights[1] self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) assert not torch.isnan(self.lvlb_weights).all() @contextmanager def ema_scope(self, context=None): if self.use_ema: self.model_ema.store(self.model.parameters()) self.model_ema.copy_to(self.model) if context is not None: print(f"{context}: Switched to EMA weights") try: yield None finally: if self.use_ema: self.model_ema.restore(self.model.parameters()) if context is not None: print(f"{context}: Restored training weights") def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): sd = torch.load(path, map_location="cpu") if "state_dict" in list(sd.keys()): sd = sd["state_dict"] keys = list(sd.keys()) for k in keys: for ik in ignore_keys: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") if len(missing) > 0: print(f"Missing Keys: {missing}") if len(unexpected) > 0: print(f"Unexpected Keys: {unexpected}") def q_mean_variance(self, x_start, t): """ Get the distribution q(x_t | x_0). :param x_start: the [N x C x ...] tensor of noiseless inputs. :param t: the number of diffusion steps (minus 1). Here, 0 means one step. :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
""" mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) return mean, variance, log_variance def predict_start_from_noise(self, x_t, t, noise): return ( extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise ) def q_posterior(self, x_start, x_t, t): posterior_mean = ( extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance(self, x, t, clip_denoised: bool): model_out = self.model(x, t) if self.parameterization == "eps": x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) elif self.parameterization == "x0": x_recon = model_out if clip_denoised: x_recon.clamp_(-1., 1.) 
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): b, *_, device = *x.shape, x.device model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) noise = noise_like(x.shape, device, repeat_noise) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() def p_sample_loop(self, shape, return_intermediates=False): device = self.betas.device b = shape[0] img = torch.randn(shape, device=device) intermediates = [img] for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), clip_denoised=self.clip_denoised) if i % self.log_every_t == 0 or i == self.num_timesteps - 1: intermediates.append(img) if return_intermediates: return img, intermediates return img @torch.no_grad() def sample(self, batch_size=16, return_intermediates=False): image_size = self.image_size channels = self.channels return self.p_sample_loop((batch_size, channels, image_size, image_size), return_intermediates=return_intermediates) def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) def get_loss(self, pred, target, mean=True): if self.loss_type == 'l1': loss = (target - pred).abs() if mean: loss = loss.mean() elif self.loss_type == 'l2': if mean: loss = torch.nn.functional.mse_loss(target, pred) else: loss = torch.nn.functional.mse_loss(target, pred, reduction='none') else:
raise NotImplementedError(f"unknown loss type '{self.loss_type}'") return loss def p_losses(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) model_out = self.model(x_noisy, t) loss_dict = {} if self.parameterization == "eps": target = noise elif self.parameterization == "x0": target = x_start else: raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported") loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) log_prefix = 'train' if self.training else 'val' loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) loss_simple = loss.mean() * self.l_simple_weight loss_vlb = (self.lvlb_weights[t] * loss).mean() loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) loss = loss_simple + self.original_elbo_weight * loss_vlb loss_dict.update({f'{log_prefix}/loss': loss}) return loss, loss_dict def forward(self, x, *args, **kwargs): # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() return self.p_losses(x, t, *args, **kwargs) def get_input(self, batch, k): x = batch[k] if len(x.shape) == 3: x = x[..., None] x = rearrange(x, 'b h w c -> b c h w') x = x.to(memory_format=torch.contiguous_format).float() return x def shared_step(self, batch): x = self.get_input(batch, self.first_stage_key) loss, loss_dict = self(x) return loss, loss_dict def training_step(self, batch, batch_idx): loss, loss_dict = self.shared_step(batch) self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) self.log("global_step", float(self.global_step), prog_bar=True, logger=True, on_step=True, on_epoch=False) if self.use_scheduler: lr = self.optimizers().param_groups[0]['lr'] self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) return loss @torch.no_grad()
def validation_step(self, batch, batch_idx): _, loss_dict_no_ema = self.shared_step(batch) with self.ema_scope(): _, loss_dict_ema = self.shared_step(batch) loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) def on_train_batch_end(self, *args, **kwargs): if self.use_ema: self.model_ema(self.model) def _get_rows_from_list(self, samples): n_imgs_per_row = len(samples) denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid @torch.no_grad() def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): log = dict() x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) x = x.to(self.device)[:N] log["inputs"] = x # get diffusion row diffusion_row = list() x_start = x[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: t = repeat(torch.tensor([t]), '1 -> b', b=n_row) t = t.to(self.device).long() noise = torch.randn_like(x_start) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) diffusion_row.append(x_noisy) log["diffusion_row"] = self._get_rows_from_list(diffusion_row) if sample: # get denoise row with self.ema_scope("Plotting"): samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) log["samples"] = samples log["denoise_row"] = self._get_rows_from_list(denoise_row) if return_keys: if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: return log else: return {key: log[key] for key in return_keys} return log def configure_optimizers(self): lr = self.learning_rate params = list(self.model.parameters()) if self.learn_logvar: params = params + [self.logvar] opt = 
torch.optim.AdamW(params, lr=lr) return opt class LatentDiffusion(DDPM): """main class""" def __init__(self, first_stage_config, cond_stage_config, num_timesteps_cond=None, cond_stage_key="image", cond_stage_trainable=False, concat_mode=True, cond_stage_forward=None, conditioning_key=None, scale_factor=1.0, *args, **kwargs): self.num_timesteps_cond = default(num_timesteps_cond, 1) assert self.num_timesteps_cond <= kwargs['timesteps'] # for backwards compatibility after implementation of DiffusionWrapper if conditioning_key is None: conditioning_key = 'concat' if concat_mode else 'crossattn' if cond_stage_config == '__is_unconditional__': conditioning_key = None ckpt_path = kwargs.pop("ckpt_path", None) ignore_keys = kwargs.pop("ignore_keys", []) super().__init__(conditioning_key=conditioning_key, *args, **kwargs) self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 except: self.num_downs = 0 self.scale_factor = scale_factor self.instantiate_first_stage(first_stage_config) self.instantiate_cond_stage(cond_stage_config) self.cond_stage_forward = cond_stage_forward self.clip_denoised = False self.restarted_from_ckpt = False if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys) self.restarted_from_ckpt = True def make_cond_schedule(self, ): self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() self.cond_ids[:self.num_timesteps_cond] = ids def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) self.shorten_cond_schedule = self.num_timesteps_cond > 1 if self.shorten_cond_schedule: 
self.make_cond_schedule() def instantiate_first_stage(self, config): model = instantiate_from_config(config) self.first_stage_model = model.eval() self.first_stage_model.train = disabled_train for param in self.first_stage_model.parameters(): param.requires_grad = False def instantiate_cond_stage(self, config): if not self.cond_stage_trainable: if config == "__is_first_stage__": print("Using first stage also as cond stage.") self.cond_stage_model = self.first_stage_model elif config == "__is_unconditional__": print(f"Training {self.__class__.__name__} as an unconditional model.") self.cond_stage_model = None # self.be_unconditional = True else: model = instantiate_from_config(config) self.cond_stage_model = model.eval() self.cond_stage_model.train = disabled_train for param in self.cond_stage_model.parameters(): param.requires_grad = False else: assert config != '__is_first_stage__' assert config != '__is_unconditional__' model = instantiate_from_config(config) self.cond_stage_model = model def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): denoise_row = [] for zd in tqdm(samples, desc=desc): denoise_row.append(self.decode_first_stage(zd.to(self.device), force_not_quantize=force_no_decoder_quantization)) n_imgs_per_row = len(denoise_row) denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid def get_first_stage_encoding(self, encoder_posterior): if isinstance(encoder_posterior, torch.Tensor): z = encoder_posterior else: raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") return self.scale_factor * z def get_learned_conditioning(self, c): if self.cond_stage_forward is None: if hasattr(self.cond_stage_model, 'encode') and 
callable(self.cond_stage_model.encode): c = self.cond_stage_model.encode(c) else: c = self.cond_stage_model(c) else: assert hasattr(self.cond_stage_model, self.cond_stage_forward) c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) return c @torch.no_grad() def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, cond_key=None, return_original_cond=False, bs=None): x = super().get_input(batch, k) if bs is not None: x = x[:bs] x = x.to(self.device) encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() if self.model.conditioning_key is not None: if cond_key is None: cond_key = self.cond_stage_key if cond_key != self.first_stage_key: if cond_key in ['caption', 'coordinates_bbox']: xc = batch[cond_key] elif cond_key == 'class_label': xc = batch else: xc = super().get_input(batch, cond_key).to(self.device) else: xc = x if not self.cond_stage_trainable or force_c_encode: if isinstance(xc, dict) or isinstance(xc, list): # import pudb; pudb.set_trace() c = self.get_learned_conditioning(xc) else: c = self.get_learned_conditioning(xc.to(self.device)) else: c = xc if bs is not None: c = c[:bs] if self.use_positional_encodings: pos_x, pos_y = self.compute_latent_shifts(batch) ckey = __conditioning_keys__[self.model.conditioning_key] c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} else: c = None xc = None if self.use_positional_encodings: pos_x, pos_y = self.compute_latent_shifts(batch) c = {'pos_x': pos_x, 'pos_y': pos_y} out = [z, c] if return_first_stage_outputs: xrec = self.decode_first_stage(z) out.extend([x, xrec]) if return_original_cond: out.append(xc) return out @torch.no_grad() def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): if predict_cids: if z.dim() == 4: z = torch.argmax(z.exp(), dim=1).long() z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) z = rearrange(z, 'b h w c -> b c h w').contiguous() z = 1. 
/ self.scale_factor * z return self.first_stage_model.decode(z) @torch.no_grad() def encode_first_stage(self, x): return self.first_stage_model.encode(x) def shared_step(self, batch, **kwargs): x, c = self.get_input(batch, self.first_stage_key) loss = self(x, c) return loss def forward(self, x, c, *args, **kwargs): t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() if self.model.conditioning_key is not None: assert c is not None if self.cond_stage_trainable: c = self.get_learned_conditioning(c) if self.shorten_cond_schedule: # TODO: drop this option tc = self.cond_ids[t].to(self.device) c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): # hybrid case, cond is expected to be a dict pass else: if not isinstance(cond, list): cond = [cond] key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' cond = {key: cond} x_recon = self.model(x_noisy, t, **cond) if isinstance(x_recon, tuple) and not return_ids: return x_recon[0] else: return x_recon def p_losses(self, x_start, cond, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) model_output = self.apply_model(x_noisy, t, cond) loss_dict = {} prefix = 'train' if self.training else 'val' if self.parameterization == "x0": target = x_start elif self.parameterization == "eps": target = noise else: raise NotImplementedError() loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) logvar_t = self.logvar[t].to(self.device) loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar if self.learn_logvar: loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) loss_dict.update({'logvar':
self.logvar.data.mean()}) loss = self.l_simple_weight * loss.mean() loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) loss += (self.original_elbo_weight * loss_vlb) loss_dict.update({f'{prefix}/loss': loss}) return loss, loss_dict def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, return_x0=False, score_corrector=None, corrector_kwargs=None): t_in = t model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) if score_corrector is not None: assert self.parameterization == "eps" model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) if return_codebook_ids: model_out, logits = model_out if self.parameterization == "eps": x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) elif self.parameterization == "x0": x_recon = model_out else: raise NotImplementedError() if clip_denoised: x_recon.clamp_(-1., 1.) 
if quantize_denoised: x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) if return_codebook_ids: return model_mean, posterior_variance, posterior_log_variance, logits elif return_x0: return model_mean, posterior_variance, posterior_log_variance, x_recon else: return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, return_codebook_ids=False, quantize_denoised=False, return_x0=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): b, *_, device = *x.shape, x.device outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, return_codebook_ids=return_codebook_ids, quantize_denoised=quantize_denoised, return_x0=return_x0, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) if return_codebook_ids: raise DeprecationWarning("Support dropped.") model_mean, _, model_log_variance, logits = outputs elif return_x0: model_mean, _, model_log_variance, x0 = outputs else: model_mean, _, model_log_variance = outputs noise = noise_like(x.shape, device, repeat_noise) * temperature if noise_dropout > 0.: noise = torch.nn.functional.dropout(noise, p=noise_dropout) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) # if return_codebook_ids: # return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) if return_x0: return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 else: return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, batch_size=None, 
x_T=None, start_T=None, log_every_t=None): if not log_every_t: log_every_t = self.log_every_t timesteps = self.num_timesteps if batch_size is not None: b = batch_size if batch_size is not None else shape[0] shape = [batch_size] + list(shape) else: b = batch_size = shape[0] if x_T is None: img = torch.randn(shape, device=self.device) else: img = x_T intermediates = [] if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else list(map(lambda x: x[:batch_size], cond[key])) for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] if start_T is not None: timesteps = min(timesteps, start_T) iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', total=timesteps) if verbose else reversed( range(0, timesteps)) if type(temperature) == float: temperature = [temperature] * timesteps for i in iterator: ts = torch.full((b,), i, device=self.device, dtype=torch.long) if self.shorten_cond_schedule: assert self.model.conditioning_key != 'hybrid' tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) img, x0_partial = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised, return_x0=True, temperature=temperature[i], noise_dropout=noise_dropout, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) if mask is not None: assert x0 is not None img_orig = self.q_sample(x0, ts) img = img_orig * mask + (1. 
- mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) if callback: callback(i) if img_callback: img_callback(img, i) return img, intermediates @torch.no_grad() def p_sample_loop(self, cond, shape, return_intermediates=False, x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, start_T=None, log_every_t=None): if not log_every_t: log_every_t = self.log_every_t device = self.betas.device b = shape[0] if x_T is None: img = torch.randn(shape, device=device) else: img = x_T intermediates = [img] if timesteps is None: timesteps = self.num_timesteps if start_T is not None: timesteps = min(timesteps, start_T) iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( range(0, timesteps)) if mask is not None: assert x0 is not None assert x0.shape[2:4] == mask.shape[2:4] # spatial size (H, W) has to match for i in iterator: ts = torch.full((b,), i, device=device, dtype=torch.long) if self.shorten_cond_schedule: assert self.model.conditioning_key != 'hybrid' tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) img = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised) if mask is not None: img_orig = self.q_sample(x0, ts) img = img_orig * mask + (1. 
- mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) if callback: callback(i) if img_callback: img_callback(img, i) if return_intermediates: return img, intermediates return img @torch.no_grad() def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, verbose=True, timesteps=None, quantize_denoised=False, mask=None, x0=None, shape=None,**kwargs): if shape is None: shape = (batch_size, self.channels, self.image_size, self.image_size) if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else list(map(lambda x: x[:batch_size], cond[key])) for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] return self.p_sample_loop(cond, shape, return_intermediates=return_intermediates, x_T=x_T, verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, mask=mask, x0=x0) @torch.no_grad() def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): if ddim: ddim_sampler = DDIMSampler(self) shape = (self.channels, self.image_size, self.image_size) samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, shape,cond,verbose=False,**kwargs) else: samples, intermediates = self.sample(cond=cond, batch_size=batch_size, return_intermediates=True,**kwargs) return samples, intermediates @torch.no_grad() def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, **kwargs): use_ddim = ddim_steps is not None log = dict() z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, return_original_cond=True, bs=N) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) log["inputs"] = x log["reconstruction"] = xrec if self.model.conditioning_key is not None: if 
hasattr(self.cond_stage_model, "decode"): xc = self.cond_stage_model.decode(c) log["conditioning"] = xc elif self.cond_stage_key in ["caption"]: xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) log["conditioning"] = xc elif self.cond_stage_key == 'class_label': xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) log['conditioning'] = xc elif isimage(xc): log["conditioning"] = xc if ismap(xc): log["original_conditioning"] = self.to_rgb(xc) if plot_diffusion_rows: # get diffusion row diffusion_row = list() z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: t = repeat(torch.tensor([t]), '1 -> b', b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy)) diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with self.ema_scope("Plotting"): samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, ddim_steps=ddim_steps,eta=ddim_eta) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples) log["samples"] = x_samples if plot_denoise_rows: denoise_grid = self._get_denoise_row_from_list(z_denoise_row) log["denoise_row"] = denoise_grid if quantize_denoised: # also display when quantizing x0 while sampling with self.ema_scope("Plotting Quantized Denoised"): samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, ddim_steps=ddim_steps,eta=ddim_eta, quantize_denoised=True) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, 
return_intermediates=True, # quantize_denoised=True) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_x0_quantized"] = x_samples if inpaint: # make a simple center square b, h, w = z.shape[0], z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. mask = mask[:, None, ...] with self.ema_scope("Plotting Inpaint"): samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_inpainting"] = x_samples log["mask"] = mask # outpaint with self.ema_scope("Plotting Outpaint"): samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, ddim_steps=ddim_steps, x0=z[:N], mask=mask) x_samples = self.decode_first_stage(samples.to(self.device)) log["samples_outpainting"] = x_samples if plot_progressive_rows: with self.ema_scope("Plotting Progressives"): img, progressives = self.progressive_denoising(c, shape=(self.channels, self.image_size, self.image_size), batch_size=N) prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") log["progressive_row"] = prog_row if return_keys: if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: return log else: return {key: log[key] for key in return_keys} return log def configure_optimizers(self): lr = self.learning_rate params = list(self.model.parameters()) if self.cond_stage_trainable: print(f"{self.__class__.__name__}: Also optimizing conditioner params!") params = params + list(self.cond_stage_model.parameters()) if self.learn_logvar: print('Diffusion model optimizing logvar') params.append(self.logvar) opt = torch.optim.AdamW(params, lr=lr) if self.use_scheduler: assert 'target' in self.scheduler_config scheduler = instantiate_from_config(self.scheduler_config) print("Setting up LambdaLR scheduler...") scheduler = [ { 'scheduler': 
LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1 }] return [opt], scheduler return opt @torch.no_grad() def to_rgb(self, x): x = x.float() if not hasattr(self, "colorize"): self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) x = nn.functional.conv2d(x, weight=self.colorize) x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. return x class DiffusionWrapper(pl.LightningModule): def __init__(self, diff_model_config, conditioning_key): super().__init__() self.diffusion_model = instantiate_from_config(diff_model_config) self.conditioning_key = conditioning_key assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'mcvd'] def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): if self.conditioning_key is None: out = self.diffusion_model(x, t) elif self.conditioning_key == 'concat': xc = torch.cat([x] + c_concat, dim=1) out = self.diffusion_model(xc, t) elif self.conditioning_key == 'crossattn': cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(x, t, context=cc) elif self.conditioning_key == 'hybrid': xc = torch.cat([x] + c_concat, dim=1) cc = torch.cat(c_crossattn, 1) out = self.diffusion_model(xc, t, context=cc) elif self.conditioning_key == 'adm': cc = c_crossattn[0] out = self.diffusion_model(x, t, y=cc) elif self.conditioning_key == 'mcvd': out = self.diffusion_model(x, t, cond=c_concat[0]) else: raise NotImplementedError() return out ######################################################### ######################################################### ######################################################### ######################################################### ######################################################### ######################################################### class LatentDiffusionVFI(DDPM): """main class""" def __init__(self, first_stage_config, cond_stage_config, num_timesteps_cond=None, cond_stage_key="image", cond_stage_trainable=False, concat_mode=True, 
cond_stage_forward=None, conditioning_key=None, scale_factor=1.0, *args, **kwargs): self.num_timesteps_cond = default(num_timesteps_cond, 1) assert self.num_timesteps_cond <= kwargs['timesteps'] # for backwards compatibility after implementation of DiffusionWrapper if conditioning_key is None: conditioning_key = 'concat' if concat_mode else 'crossattn' if cond_stage_config == '__is_unconditional__': conditioning_key = None ckpt_path = kwargs.pop("ckpt_path", None) ignore_keys = kwargs.pop("ignore_keys", []) super().__init__(conditioning_key=conditioning_key, *args, **kwargs) self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 except: self.num_downs = 0 self.scale_factor = scale_factor self.instantiate_first_stage(first_stage_config) self.instantiate_cond_stage(cond_stage_config) self.cond_stage_forward = cond_stage_forward self.clip_denoised = False self.restarted_from_ckpt = False if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys) self.restarted_from_ckpt = True def make_cond_schedule(self, ): self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() self.cond_ids[:self.num_timesteps_cond] = ids def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) self.shorten_cond_schedule = self.num_timesteps_cond > 1 if self.shorten_cond_schedule: self.make_cond_schedule() def instantiate_first_stage(self, config): model = instantiate_from_config(config) self.first_stage_model = model.eval() self.first_stage_model.train = disabled_train for param in self.first_stage_model.parameters(): 
param.requires_grad = False def instantiate_cond_stage(self, config): if not self.cond_stage_trainable: if config == "__is_first_stage__": print("Using first stage also as cond stage.") self.cond_stage_model = self.first_stage_model elif config == "__is_unconditional__": print(f"Training {self.__class__.__name__} as an unconditional model.") self.cond_stage_model = None # self.be_unconditional = True else: model = instantiate_from_config(config) self.cond_stage_model = model.eval() self.cond_stage_model.train = disabled_train for param in self.cond_stage_model.parameters(): param.requires_grad = False else: assert config != '__is_first_stage__' assert config != '__is_unconditional__' model = instantiate_from_config(config) self.cond_stage_model = model def _get_denoise_row_from_list(self, samples, xc=None, phi_prev_list=None, phi_next_list=None, desc='', force_no_decoder_quantization=False): denoise_row = [] for zd in tqdm(samples, desc=desc): denoise_row.append(self.decode_first_stage(zd.to(self.device), xc, phi_prev_list, phi_next_list, force_not_quantize=force_no_decoder_quantization)) n_imgs_per_row = len(denoise_row) denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid def get_first_stage_encoding(self, encoder_posterior): if isinstance(encoder_posterior, torch.Tensor): z = encoder_posterior else: raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") return self.scale_factor * z def get_learned_conditioning(self, c): phi_prev_list, phi_next_list = None, None if isinstance(c, dict) and 'prev_frame' in c.keys(): c_prev, phi_prev_list = self.cond_stage_model.encode(c['prev_frame'], ret_feature=True) c_next, phi_next_list = self.cond_stage_model.encode(c['next_frame'], ret_feature=True) c = 
torch.cat([c_prev, c_next], dim=1) else: c = self.cond_stage_model.encode(c) return c, phi_prev_list, phi_next_list @torch.no_grad() def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, cond_key=None, return_original_cond=False, return_phi=False, bs=None): x = super().get_input(batch, k) if bs is not None: x = x[:bs] x = x.to(self.device) encoder_posterior = self.encode_first_stage(x) z = self.get_first_stage_encoding(encoder_posterior).detach() phi_prev_list, phi_next_list = None, None assert self.model.conditioning_key is not None if cond_key is None: cond_key = self.cond_stage_key assert cond_key == 'past_future_frames' xc = {'prev_frame': super().get_input(batch, 'prev_frame'), 'next_frame': super().get_input(batch, 'next_frame')} if not self.cond_stage_trainable or force_c_encode: if isinstance(xc, dict) or isinstance(xc, list): # import pudb; pudb.set_trace() c, phi_prev_list, phi_next_list = self.get_learned_conditioning(xc) else: c, phi_prev_list, phi_next_list = self.get_learned_conditioning(xc.to(self.device)) else: c = xc if bs is not None: c = c[:bs] if isinstance(xc, dict): xc['prev_frame'] = xc['prev_frame'][:bs] xc['next_frame'] = xc['next_frame'][:bs] if phi_prev_list and phi_next_list: phi_prev_list = [phi_prev[:bs] for phi_prev in phi_prev_list] phi_next_list = [phi_next[:bs] for phi_next in phi_next_list] out = [z, c] if return_first_stage_outputs: xrec = self.decode_first_stage(z, xc, phi_prev_list, phi_next_list) out.extend([x, xrec]) if return_original_cond: out.append(xc) if return_phi: out.append(phi_prev_list) out.append(phi_next_list) return out @torch.no_grad() def decode_first_stage(self, z, xc=None, phi_prev_list=None, phi_next_list=None, predict_cids=False, force_not_quantize=False): if predict_cids: if z.dim() == 4: z = torch.argmax(z.exp(), dim=1).long() z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) z = rearrange(z, 'b h w c -> b c h w').contiguous() z = 1. 
/ self.scale_factor * z return self.first_stage_model.decode(z, xc['prev_frame'], xc['next_frame'], phi_prev_list, phi_next_list, force_not_quantize=predict_cids or force_not_quantize) @torch.no_grad() def encode_first_stage(self, x): return self.first_stage_model.encode(x) def shared_step(self, batch, **kwargs): x, c = self.get_input(batch, self.first_stage_key) loss = self(x, c) return loss def forward(self, x, c, *args, **kwargs): t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() if self.model.conditioning_key is not None: assert c is not None if self.cond_stage_trainable: c, _, _ = self.get_learned_conditioning(c) if self.shorten_cond_schedule: # TODO: drop this option tc = self.cond_ids[t].to(self.device) c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): # hybrid case, cond is expected to be a dict pass else: if not isinstance(cond, list): cond = [cond] key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' cond = {key: cond} x_recon = self.model(x_noisy, t, **cond) if isinstance(x_recon, tuple) and not return_ids: return x_recon[0] else: return x_recon def p_losses(self, x_start, cond, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) model_output = self.apply_model(x_noisy, t, cond) loss_dict = {} prefix = 'train' if self.training else 'val' if self.parameterization == "x0": target = x_start elif self.parameterization == "eps": target = noise else: raise NotImplementedError() loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) logvar_t = self.logvar[t].to(self.device) loss = loss_simple / torch.exp(logvar_t) + logvar_t # loss = loss_simple / torch.exp(self.logvar) + self.logvar if 
self.learn_logvar: loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) loss_dict.update({'logvar': self.logvar.data.mean()}) loss = self.l_simple_weight * loss.mean() loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) loss += (self.original_elbo_weight * loss_vlb) loss_dict.update({f'{prefix}/loss': loss}) return loss, loss_dict def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, return_x0=False, score_corrector=None, corrector_kwargs=None): t_in = t model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) if score_corrector is not None: assert self.parameterization == "eps" model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) if return_codebook_ids: model_out, logits = model_out if self.parameterization == "eps": x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) elif self.parameterization == "x0": x_recon = model_out else: raise NotImplementedError() if clip_denoised: x_recon.clamp_(-1., 1.) 
if quantize_denoised: x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) if return_codebook_ids: return model_mean, posterior_variance, posterior_log_variance, logits elif return_x0: return model_mean, posterior_variance, posterior_log_variance, x_recon else: return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, return_codebook_ids=False, quantize_denoised=False, return_x0=False, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): b, *_, device = *x.shape, x.device outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, return_codebook_ids=return_codebook_ids, quantize_denoised=quantize_denoised, return_x0=return_x0, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) if return_codebook_ids: raise DeprecationWarning("Support dropped.") model_mean, _, model_log_variance, logits = outputs elif return_x0: model_mean, _, model_log_variance, x0 = outputs else: model_mean, _, model_log_variance = outputs noise = noise_like(x.shape, device, repeat_noise) * temperature if noise_dropout > 0.: noise = torch.nn.functional.dropout(noise, p=noise_dropout) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) # if return_codebook_ids: # return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) if return_x0: return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 else: return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, batch_size=None, 
x_T=None, start_T=None, log_every_t=None): if not log_every_t: log_every_t = self.log_every_t timesteps = self.num_timesteps if batch_size is not None: b = batch_size if batch_size is not None else shape[0] shape = [batch_size] + list(shape) else: b = batch_size = shape[0] if x_T is None: img = torch.randn(shape, device=self.device) else: img = x_T intermediates = [] if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else list(map(lambda x: x[:batch_size], cond[key])) for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] if start_T is not None: timesteps = min(timesteps, start_T) iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', total=timesteps) if verbose else reversed( range(0, timesteps)) if type(temperature) == float: temperature = [temperature] * timesteps for i in iterator: ts = torch.full((b,), i, device=self.device, dtype=torch.long) if self.shorten_cond_schedule: assert self.model.conditioning_key != 'hybrid' tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) img, x0_partial = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised, return_x0=True, temperature=temperature[i], noise_dropout=noise_dropout, score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) if mask is not None: assert x0 is not None img_orig = self.q_sample(x0, ts) img = img_orig * mask + (1. 
- mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) if callback: callback(i) if img_callback: img_callback(img, i) return img, intermediates @torch.no_grad() def p_sample_loop(self, cond, shape, return_intermediates=False, x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, start_T=None, log_every_t=None): if not log_every_t: log_every_t = self.log_every_t device = self.betas.device b = shape[0] if x_T is None: img = torch.randn(shape, device=device) else: img = x_T intermediates = [img] if timesteps is None: timesteps = self.num_timesteps if start_T is not None: timesteps = min(timesteps, start_T) iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( range(0, timesteps)) if mask is not None: assert x0 is not None assert x0.shape[2:4] == mask.shape[2:4] # spatial size (H, W) has to match for i in iterator: ts = torch.full((b,), i, device=device, dtype=torch.long) if self.shorten_cond_schedule: assert self.model.conditioning_key != 'hybrid' tc = self.cond_ids[ts].to(cond.device) cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) img = self.p_sample(img, cond, ts, clip_denoised=self.clip_denoised, quantize_denoised=quantize_denoised) if mask is not None: img_orig = self.q_sample(x0, ts) img = img_orig * mask + (1. 
- mask) * img if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) if callback: callback(i) if img_callback: img_callback(img, i) if return_intermediates: return img, intermediates return img @torch.no_grad() def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, verbose=True, timesteps=None, quantize_denoised=False, mask=None, x0=None, shape=None,**kwargs): if shape is None: shape = (batch_size, self.channels, self.image_size, self.image_size) if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else list(map(lambda x: x[:batch_size], cond[key])) for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] return self.p_sample_loop(cond, shape, return_intermediates=return_intermediates, x_T=x_T, verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, mask=mask, x0=x0) @torch.no_grad() def sample_ddpm(self, conditioning, batch_size=16, return_intermediates=False, x_T=None, verbose=True, timesteps=None, quantize_denoised=False, mask=None, x0=None, shape=None,**kwargs): if shape is None: shape = (batch_size, self.channels, self.image_size, self.image_size) elif len(shape) == 3: shape = (batch_size,) + shape if conditioning is not None: if isinstance(conditioning, dict): conditioning = {key: conditioning[key][:batch_size] if not isinstance(conditioning[key], list) else list(map(lambda x: x[:batch_size], conditioning[key])) for key in conditioning} else: conditioning = [c[:batch_size] for c in conditioning] if isinstance(conditioning, list) else conditioning[:batch_size] return self.p_sample_loop(conditioning, shape, return_intermediates=return_intermediates, x_T=x_T, verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, mask=mask, x0=x0) @torch.no_grad() def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): if ddim: ddim_sampler = DDIMSampler(self) shape = 
(self.channels, self.image_size, self.image_size) samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, shape,cond,verbose=False,**kwargs) else: samples, intermediates = self.sample(cond=cond, batch_size=batch_size, return_intermediates=True,**kwargs) return samples, intermediates @torch.no_grad() def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, quantize_denoised=True, plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, **kwargs): use_ddim = ddim_steps is not None log = dict() z, c, x, xrec, xc, phi_prev_list, phi_next_list = self.get_input(batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, return_original_cond=True, return_phi=True, bs=N) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) log["inputs"] = x log["reconstruction"] = xrec if plot_diffusion_rows: # get diffusion row diffusion_row = list() z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: t = repeat(torch.tensor([t]), '1 -> b', b=n_row) t = t.to(self.device).long() noise = torch.randn_like(z_start) z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) diffusion_row.append(self.decode_first_stage(z_noisy, xc, phi_prev_list, phi_next_list)) diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) log["diffusion_row"] = diffusion_grid if sample: # get denoise row with self.ema_scope("Plotting"): samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, ddim_steps=ddim_steps,eta=ddim_eta, x_T=None) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) x_samples = self.decode_first_stage(samples, xc, phi_prev_list, phi_next_list) 
log["samples"] = x_samples if plot_denoise_rows: denoise_grid = self._get_denoise_row_from_list(z_denoise_row, xc, phi_prev_list, phi_next_list) log["denoise_row"] = denoise_grid if quantize_denoised: # also display when quantizing x0 while sampling with self.ema_scope("Plotting Quantized Denoised"): samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, ddim_steps=ddim_steps,eta=ddim_eta, quantize_denoised=True, x_T=None) # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, # quantize_denoised=True) x_samples = self.decode_first_stage(samples.to(self.device), xc, phi_prev_list, phi_next_list) log["samples_x0_quantized"] = x_samples if plot_progressive_rows: with self.ema_scope("Plotting Progressives"): img, progressives = self.progressive_denoising(c, shape=(self.channels, self.image_size, self.image_size), batch_size=N, x_T=None) prog_row = self._get_denoise_row_from_list(progressives, xc, phi_prev_list, phi_next_list, desc="Progressive Generation") log["progressive_row"] = prog_row if return_keys: if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: return log else: return {key: log[key] for key in return_keys} return log def configure_optimizers(self): lr = self.learning_rate params = list(self.model.parameters()) if self.cond_stage_trainable: print(f"{self.__class__.__name__}: Also optimizing conditioner params!") params = params + list(self.cond_stage_model.parameters()) if self.learn_logvar: print('Diffusion model optimizing logvar') params.append(self.logvar) opt = torch.optim.AdamW(params, lr=lr) if self.use_scheduler: assert 'target' in self.scheduler_config scheduler = instantiate_from_config(self.scheduler_config) print("Setting up LambdaLR scheduler...") scheduler = [ { 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1 }] return [opt], scheduler return opt @torch.no_grad() def to_rgb(self, x): x = x.float() if not hasattr(self, "colorize"): 
self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) x = nn.functional.conv2d(x, weight=self.colorize) x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. return x ================================================ FILE: ldm/modules/attention.py ================================================ from inspect import isfunction import math import torch import torch.nn.functional as F from torch import nn, einsum from einops import rearrange, repeat from ldm.modules.diffusionmodules.util import checkpoint def exists(val): return val is not None def uniq(arr): return{el: True for el in arr}.keys() def default(val, d): if exists(val): return val return d() if isfunction(d) else d def max_neg_value(t): return -torch.finfo(t.dtype).max def init_(tensor): dim = tensor.shape[-1] std = 1 / math.sqrt(dim) tensor.uniform_(-std, std) return tensor # feedforward class GEGLU(nn.Module): def __init__(self, dim_in, dim_out): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) def forward(self, x): x, gate = self.proj(x).chunk(2, dim=-1) return x * F.gelu(gate) class FeedForward(nn.Module): def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): super().__init__() inner_dim = int(dim * mult) dim_out = default(dim_out, dim) project_in = nn.Sequential( nn.Linear(dim, inner_dim), nn.GELU() ) if not glu else GEGLU(dim, inner_dim) self.net = nn.Sequential( project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) ) def forward(self, x): return self.net(x) def zero_module(module): """ Zero out the parameters of a module and return it. 
""" for p in module.parameters(): p.detach().zero_() return module def Normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) class LinearAttention(nn.Module): def __init__(self, dim, heads=4, dim_head=32): super().__init__() self.heads = heads hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) k = k.softmax(dim=-1) context = torch.einsum('bhdn,bhen->bhde', k, v) out = torch.einsum('bhde,bhdn->bhen', context, q) out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) return self.to_out(out) class SpatialSelfAttention(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) # compute attention b,c,h,w = q.shape q = rearrange(q, 'b c h w -> b (h w) c') k = rearrange(k, 'b c h w -> b c (h w)') w_ = torch.einsum('bij,bjk->bik', q, k) w_ = w_ * (int(c)**(-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values v = rearrange(v, 'b c h w -> b c (h w)') w_ = rearrange(w_, 'b i j -> b j i') h_ = torch.einsum('bij,bjk->bik', v, w_) h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) h_ = self.proj_out(h_) return x+h_ class CrossAttention(nn.Module): ''' Perform self-attention if context is None, else cross-attention. 
The dims of the input and output of the block are the same (arg query_dim). ''' def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) self.scale = dim_head ** -0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) def forward(self, x, context=None, mask=None): h = self.heads q = self.to_q(x) context = default(context, x) k = self.to_k(context) v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) sim = einsum('b i d, b j d -> b i j', q, k) * self.scale if exists(mask): mask = rearrange(mask, 'b ... -> b (...)') max_neg_value = -torch.finfo(sim.dtype).max mask = repeat(mask, 'b j -> (b h) () j', h=h) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of attn = sim.softmax(dim=-1) out = einsum('b i j, b j d -> b i d', attn, v) out = rearrange(out, '(b h) n d -> b n (h d)', h=h) return self.to_out(out) class SpatialCrossAttention(nn.Module): ''' Cross-attention block for image-like data. First image reshape to b, t, d. Perform self-attention if context is None, else cross-attention. The dims of the input and output of the block are the same (arg query_dim). 
''' def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) self.scale = dim_head ** -0.5 self.heads = heads self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) self.norm = nn.LayerNorm(query_dim) def forward(self, x, context=None): # re-arrange image data to b, t, d. b, c, h, w = x.shape x_in = x x = rearrange(x, 'b c h w -> b (h w) c') if exists(context) and len(context.shape) == 4: context = rearrange(context, 'b c h w -> b (h w) c') heads = self.heads x = self.norm(x) q = self.to_q(x) context = default(context, x) k = self.to_k(context) v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=heads), (q, k, v)) sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # attention, what we cannot get enough of attn = sim.softmax(dim=-1) out = einsum('b i j, b j d -> b i d', attn, v) out = rearrange(out, '(b h) n d -> b n (h d)', h=heads) out = self.to_out(out) # restore image shape out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w) return x_in + out def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32): ''' Borrowed from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/simple_vit.py ''' _, dim, h, w, device, dtype = *patches.shape, patches.device, patches.dtype y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij') assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb' omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1) omega = 1.
/ (temperature ** omega) y = y.flatten()[:, None] * omega[None, :] x = x.flatten()[:, None] * omega[None, :] pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1) return pe.type(dtype) # (n, hd) class SpatialCrossAttentionWithPosEmb(nn.Module): ''' Cross-attention block for image-like data. The image is first reshaped to b, n, d. Perform self-attention if context is None, else cross-attention. The dims of the input and output of the block are the same (arg in_channels). ''' def __init__(self, in_channels=None, heads=8, dim_head=64, dropout=0.): super().__init__() inner_dim = dim_head * heads self.scale = dim_head ** -0.5 self.heads = heads self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) self.to_q = nn.Linear(inner_dim, inner_dim, bias=False) self.to_k = nn.Linear(inner_dim, inner_dim, bias=False) self.to_v = nn.Linear(inner_dim, inner_dim, bias=False) self.to_out = nn.Sequential( nn.Linear(inner_dim, inner_dim), nn.Dropout(dropout) ) self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) self.norm = nn.LayerNorm(inner_dim) def forward(self, x, context=None): b, c, h, w = x.shape x_in = x context = default(context, x) x = self.proj_in(x) # (b,d,h,w) context = self.proj_in(context) # (b,d,h,w) # positional embedding pe = posemb_sincos_2d(x) # (n,d) # re-arrange image data to b, n, d.
x = rearrange(x, 'b c h w -> b (h w) c') if (len(context.shape) == 4): context = rearrange(context, 'b c h w -> b (h w) c') # add pos emb x += pe if context.shape[1] != x.shape[1]: context[:,:h*w] += pe context[:,h*w:] += pe else: context += pe heads = self.heads x = self.norm(x) context = self.norm(context) q = self.to_q(x) k = self.to_k(context) v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=heads), (q, k, v)) sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # attention, what we cannot get enough of attn = sim.softmax(dim=-1) out = einsum('b i j, b j d -> b i d', attn, v) out = rearrange(out, '(b h) n d -> b n (h d)', h=heads) out = self.to_out(out) # restore image shape out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w) return x_in + out class BasicTransformerBlock(nn.Module): ''' Two CrossAttention modules followed by a fully connected layer. The first CrossAttention is applied to x in a self-attention manner. The second CrossAttention is applied to x and context as cross attention. The fully connected layer has 4x internal dimension. The dims of the input and output of the block are the same (arg dim).
''' def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none self.norm1 = nn.LayerNorm(dim) self.norm2 = nn.LayerNorm(dim) self.norm3 = nn.LayerNorm(dim) self.checkpoint = checkpoint def forward(self, x, context=None): return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) def _forward(self, x, context=None): x = self.attn1(self.norm1(x)) + x # (8,4096,256) x = self.attn2(self.norm2(x), context=context) + x x = self.ff(self.norm3(x)) + x return x class SpatialTransformer(nn.Module): """ Transformer block for image-like data. First, project the input (aka embedding) to inner_dim (d) using conv1x1 Then reshape to b, t, d. Then apply standard transformer action (BasicTransformerBlock). Finally, reshape to image and pass to output conv1x1 layer, to restore the channel size of input. The dims of the input and output of the block are the same (arg in_channels). 
""" def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None): super().__init__() self.in_channels = in_channels inner_dim = n_heads * d_head self.norm = Normalize(in_channels) self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) self.transformer_blocks = nn.ModuleList( [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) for d in range(depth)] ) self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) def forward(self, x, context=None): # note: if no context is given, cross-attention defaults to self-attention b, c, h, w = x.shape # here x and context might have different resolutions # because x is being downsampled while c is not x_in = x x = self.norm(x) x = self.proj_in(x) x = rearrange(x, 'b c h w -> b (h w) c') if (len(context.shape) == 4): context = rearrange(context, 'b c h w -> b (h w) c') for block in self.transformer_blocks: x = block(x, context=context) x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = self.proj_out(x) return x + x_in ================================================ FILE: ldm/modules/diffusionmodules/__init__.py ================================================ ================================================ FILE: ldm/modules/diffusionmodules/model.py ================================================ # pytorch_diffusion + derived encoder decoder import math import torch import torch.nn as nn import numpy as np from ldm.modules.attention import LinearAttention, SpatialCrossAttentionWithPosEmb from ldm.modules.maxvit import SpatialCrossAttentionWithMax, MaxAttentionBlock from cupy_module import dsepconv def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal embeddings. 
This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of "Attention Is All You Need". """ assert len(timesteps.shape) == 1 half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0,1,0,0)) return emb def nonlinearity(x): # swish return x*torch.sigmoid(x) def Normalize(in_channels, num_groups=32): return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) class IdentityWrapper(nn.Module): """ A wrapper for nn.Identity that allows additional input. """ def __init__(self) -> None: super().__init__() self.layer = nn.Identity() def forward(self, x, context=None): return self.layer(x) class Upsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x): x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") if self.with_conv: x = self.conv(x) return x class Downsample(nn.Module): def __init__(self, in_channels, with_conv): super().__init__() self.with_conv = with_conv if self.with_conv: # no asymmetric padding in torch conv, must do it ourselves self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) def forward(self, x): if self.with_conv: pad = (0,1,0,1) x = torch.nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) return x class ResnetBlock(nn.Module): def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout, temb_channels=512): super().__init__() 
self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.use_conv_shortcut = conv_shortcut self.norm1 = Normalize(in_channels) self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels > 0: self.temb_proj = torch.nn.Linear(temb_channels, out_channels) self.norm2 = Normalize(out_channels) self.dropout = torch.nn.Dropout(dropout) self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: if self.use_conv_shortcut: self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) else: self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x, temb): h = x h = self.norm1(h) h = nonlinearity(h) h = self.conv1(h) if temb is not None: h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] h = self.norm2(h) h = nonlinearity(h) h = self.dropout(h) h = self.conv2(h) if self.in_channels != self.out_channels: if self.use_conv_shortcut: x = self.conv_shortcut(x) else: x = self.nin_shortcut(x) return x+h class LinAttnBlock(LinearAttention): """to match AttnBlock usage""" def __init__(self, in_channels): super().__init__(dim=in_channels, heads=1, dim_head=in_channels) class AttnBlock(nn.Module): def __init__(self, in_channels): super().__init__() self.in_channels = in_channels self.norm = Normalize(in_channels) self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h_ = x h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) # compute attention b,c,h,w 
= q.shape q = q.reshape(b,c,h*w) q = q.permute(0,2,1) # b,hw,c k = k.reshape(b,c,h*w) # b,c,hw w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] w_ = w_ * (int(c)**(-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values v = v.reshape(b,c,h*w) w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] h_ = h_.reshape(b,c,h,w) h_ = self.proj_out(h_) return x+h_ def make_attn(in_channels, attn_type="vanilla"): assert attn_type in ["vanilla", "linear", "none", 'max'], f'attn_type {attn_type} unknown' print(f"making attention of type '{attn_type}' with {in_channels} in_channels") if attn_type == "vanilla": return AttnBlock(in_channels) elif attn_type == "none": return nn.Identity(in_channels) elif attn_type == 'max': return MaxAttentionBlock(in_channels, heads=1, dim_head=in_channels) else: return LinAttnBlock(in_channels) class FIEncoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", **ignore_kwargs): super().__init__() if use_linear_attn: attn_type = "linear" self.ch = ch # 128 self.temb_ch = 0 self.num_resolutions = len(ch_mult) # 3 self.num_res_blocks = num_res_blocks # 2 self.resolution = resolution # 256 self.in_channels = in_channels # 3 # downsampling self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = resolution # 256 in_ch_mult = (1,)+tuple(ch_mult) # (1,1,2,4) self.in_ch_mult = in_ch_mult # (1,1,2,4) self.down = nn.ModuleList() for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = int(ch*in_ch_mult[i_level]) block_out = int(ch*ch_mult[i_level]) for i_block in range(self.num_res_blocks): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, 
temb_channels=self.temb_ch, dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) down = nn.Module() down.block = block down.attn = attn # if i_level != self.num_resolutions-1: down.downsample = Downsample(block_in, resamp_with_conv) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, 2*z_channels if double_z else z_channels, # 3 kernel_size=3, stride=1, padding=1) def forward(self, x, ret_feature=False): # timestep embedding temb = None # downsampling hs = [self.conv_in(x)] phi_list = [] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1], temb) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) # if i_level != self.num_resolutions-1: hs.append(self.down[i_level].downsample(hs[-1])) phi_list.append(hs[-1]) # middle h = hs[-1] h = self.mid.block_1(h, temb) h = self.mid.attn_1(h) h = self.mid.block_2(h, temb) # end h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h) if ret_feature: return h, phi_list return h class FlowEncoder(FIEncoder): def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0, resamp_with_conv=True, in_channels, resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", **ignore_kwargs): super().__init__( ch=ch, out_ch=out_ch, ch_mult=ch_mult, num_res_blocks=num_res_blocks, attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, 
in_channels=in_channels, resolution=resolution, z_channels=z_channels, double_z=double_z, use_linear_attn=use_linear_attn, attn_type=attn_type, **ignore_kwargs ) class FlowDecoderWithResidual(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, attn_type="vanilla", num_head_channels=32, num_heads=1, cond_type=None, **ignorekwargs): super().__init__() def KernelHead(c_in): return torch.nn.Sequential( torch.nn.Conv2d(in_channels=c_in, out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.ReLU(inplace=False), torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.ReLU(inplace=False), torch.nn.Conv2d(in_channels=32, out_channels=5, kernel_size=3, stride=1, padding=1), torch.nn.ReLU(inplace=False), # torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), torch.nn.Conv2d(in_channels=5, out_channels=5, kernel_size=3, stride=1, padding=1) ) # end def OffsetHead(c_in): return torch.nn.Sequential( torch.nn.Conv2d(in_channels=c_in, out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.ReLU(inplace=False), torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.ReLU(inplace=False), torch.nn.Conv2d(in_channels=32, out_channels=5 ** 2, kernel_size=3, stride=1, padding=1), torch.nn.ReLU(inplace=False), # torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), torch.nn.Conv2d(in_channels=5 ** 2, out_channels=5 ** 2, kernel_size=3, stride=1, padding=1) ) def MaskHead(c_in): return torch.nn.Sequential( torch.nn.Conv2d(in_channels=c_in, out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.ReLU(inplace=False), torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.ReLU(inplace=False), torch.nn.Conv2d(in_channels=32, out_channels=5 ** 
2, kernel_size=3, stride=1, padding=1), torch.nn.ReLU(inplace=False), # torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), torch.nn.Conv2d(in_channels=5 ** 2, out_channels=5 ** 2, kernel_size=3, stride=1, padding=1), torch.nn.Sigmoid() ) def ResidualHead(c_in): return torch.nn.Sequential( torch.nn.Conv2d(in_channels=c_in, out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.ReLU(inplace=False), torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.ReLU(inplace=False), torch.nn.Conv2d(in_channels=32, out_channels=3, kernel_size=3, stride=1, padding=1), torch.nn.ReLU(inplace=False), # torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1) ) self.ch = ch # 128 self.temb_ch = 0 self.num_resolutions = len(ch_mult) # 3 self.num_res_blocks = num_res_blocks # 2 self.resolution = resolution # 256 self.in_channels = in_channels # 3 self.give_pre_end = give_pre_end # False self.tanh_out = tanh_out # False # compute in_ch_mult, block_in and curr_res at lowest res in_ch_mult = (1,)+tuple(ch_mult) # (1,1,2,4) block_in = int(ch*ch_mult[self.num_resolutions-1]) # 512 curr_res = resolution // 2**(self.num_resolutions-1) # 64 self.z_shape = (1,z_channels,curr_res,curr_res) # (1,3,64,64) print("Working with z of shape {} = {} dimensions.".format( self.z_shape, np.prod(self.z_shape))) # z to block_in self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): # 2,1,0 
block = nn.ModuleList() attn = nn.ModuleList() block_out = int(ch*ch_mult[i_level]) # ResBlocks for i_block in range(self.num_res_blocks): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, temb_channels=self.temb_ch, dropout=dropout)) block_in = block_out if curr_res in attn_resolutions: attn.append(make_attn(block_in, attn_type=attn_type)) # CrossAttention if num_head_channels == -1: dim_head = block_in // num_heads else: num_heads = block_in // num_head_channels dim_head = num_head_channels # 32 if cond_type == 'cross_attn': cross_attn = SpatialCrossAttentionWithPosEmb(in_channels=block_in, heads=num_heads, dim_head=dim_head) elif cond_type == 'max_cross_attn': cross_attn = SpatialCrossAttentionWithMax(in_channels=block_in, heads=num_heads, dim_head=dim_head) elif cond_type == 'max_cross_attn_frame': cross_attn = SpatialCrossAttentionWithMax(in_channels=block_in, heads=num_heads, dim_head=dim_head, ctx_dim=6) else: cross_attn = IdentityWrapper() up = nn.Module() up.block = block up.attn = attn up.cross_attn = cross_attn # Upsample # if i_level != self.num_resolutions-1: ## THIS IS ORIGINAL CODE # if i_level != 0: up.upsample = Upsample(block_in, resamp_with_conv) curr_res = curr_res * 2 self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = Normalize(block_in) self.conv_out = torch.nn.Conv2d(block_in, block_in, kernel_size=3, stride=1, padding=1) self.moduleAlpha1 = OffsetHead(c_in=block_in) self.moduleAlpha2 = OffsetHead(c_in=block_in) self.moduleBeta1 = OffsetHead(c_in=block_in) self.moduleBeta2 = OffsetHead(c_in=block_in) self.moduleKernelHorizontal1 = KernelHead(c_in=block_in) self.moduleKernelHorizontal2 = KernelHead(c_in=block_in) self.moduleKernelVertical1 = KernelHead(c_in=block_in) self.moduleKernelVertical2 = KernelHead(c_in=block_in) self.moduleMask = MaskHead(c_in=block_in) self.moduleResidual = ResidualHead(c_in=block_in) self.modulePad = torch.nn.ReplicationPad2d([2, 2, 2, 2]) def forward(self, z, 
cond_dict): phi_prev_list = cond_dict['phi_prev_list'] phi_next_list = cond_dict['phi_next_list'] frame_prev = cond_dict['frame_prev'] frame_next = cond_dict['frame_next'] #assert z.shape[1:] == self.z_shape[1:] self.last_z_shape = z.shape # timestep embedding temb = None # z to block_in h = self.conv_in(z) # middle h = self.mid.block_1(h, temb) h = self.mid.attn_1(h) h = self.mid.block_2(h, temb) # upsampling for i_level in reversed(range(self.num_resolutions)): # [2,1,0] for i_block in range(self.num_res_blocks): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) ctx = None if phi_prev_list[i_level] is not None: ctx = torch.cat([phi_prev_list[i_level], phi_next_list[i_level]], dim=1) h = self.up[i_level].cross_attn(h, ctx) # if i_level != self.num_resolutions-1: # if i_level != 0: h = self.up[i_level].upsample(h) # end if self.give_pre_end: return h h = self.norm_out(h) h = nonlinearity(h) h = self.conv_out(h) alpha1 = self.moduleAlpha1(h) alpha2 = self.moduleAlpha2(h) beta1 = self.moduleBeta1(h) beta2 = self.moduleBeta2(h) v1 = self.moduleKernelVertical1(h) v2 = self.moduleKernelVertical2(h) h1 = self.moduleKernelHorizontal1(h) h2 = self.moduleKernelHorizontal2(h) mask1 = self.moduleMask(h) mask2 = 1.0 - mask1 warped1 = dsepconv.FunctionDSepconv(self.modulePad(frame_prev), v1, h1, alpha1, beta1, mask1) warped2 = dsepconv.FunctionDSepconv(self.modulePad(frame_next), v2, h2, alpha2, beta2, mask2) warped = warped1 + warped2 out = warped + self.moduleResidual(h) return out ================================================ FILE: ldm/modules/diffusionmodules/openaimodel.py ================================================ from abc import abstractmethod from functools import partial import math from typing import Iterable import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F from ldm.modules.maxvit import SpatialTransformerWithMax, MaxAttentionBlock from 
ldm.modules.diffusionmodules.util import ( checkpoint, conv_nd, linear, avg_pool_nd, zero_module, normalization, timestep_embedding, ) from ldm.modules.attention import SpatialTransformer # dummy replace def convert_module_to_f16(x): pass def convert_module_to_f32(x): pass ## go class AttentionPool2d(nn.Module): """ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py """ def __init__( self, spacial_dim: int, embed_dim: int, num_heads_channels: int, output_dim: int = None, ): super().__init__() self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) self.num_heads = embed_dim // num_heads_channels self.attention = QKVAttention(self.num_heads) def forward(self, x): b, c, *_spatial = x.shape x = x.reshape(b, c, -1) # NC(HW) x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) x = self.qkv_proj(x) x = self.attention(x) x = self.c_proj(x) return x[:, :, 0] class TimestepBlock(nn.Module): """ Any module where forward() takes timestep embeddings as a second argument. """ @abstractmethod def forward(self, x, emb): """ Apply the module to `x` given `emb` timestep embeddings. """ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): """ A sequential module that passes timestep embeddings to the children that support it as an extra input. """ def forward(self, x, emb, context=None): for layer in self: if isinstance(layer, TimestepBlock): x = layer(x, emb) elif isinstance(layer, SpatialTransformer) or isinstance(layer, SpatialTransformerWithMax): x = layer(x, context) else: x = layer(x) return x class Upsample(nn.Module): """ An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied. 
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. """ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims if use_conv: self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) def forward(self, x): assert x.shape[1] == self.channels if self.dims == 3: x = F.interpolate( x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" ) else: x = F.interpolate(x, scale_factor=2, mode="nearest") if self.use_conv: x = self.conv(x) return x class TransposedUpsample(nn.Module): 'Learned 2x upsampling without padding' def __init__(self, channels, out_channels=None, ks=5): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) def forward(self,x): return self.up(x) class Downsample(nn.Module): """ A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. """ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): super().__init__() self.channels = channels self.out_channels = out_channels or channels self.use_conv = use_conv self.dims = dims stride = 2 if dims != 3 else (1, 2, 2) if use_conv: self.op = conv_nd( dims, self.channels, self.out_channels, 3, stride=stride, padding=padding ) else: assert self.channels == self.out_channels self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) def forward(self, x): assert x.shape[1] == self.channels return self.op(x) class ResBlock(TimestepBlock): """ A residual block that can optionally change the number of channels. 
:param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use a spatial convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for downsampling. """ def __init__( self, channels, emb_channels, dropout, out_channels=None, use_conv=False, use_scale_shift_norm=False, dims=2, use_checkpoint=False, up=False, down=False, ): super().__init__() self.channels = channels # 256 self.emb_channels = emb_channels # 1024 self.dropout = dropout # 0 self.out_channels = out_channels or channels # 256 self.use_conv = use_conv # False self.use_checkpoint = use_checkpoint # False self.use_scale_shift_norm = use_scale_shift_norm # False self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), conv_nd(dims, channels, self.out_channels, 3, padding=1), ) self.updown = up or down if up: self.h_upd = Upsample(channels, False, dims) self.x_upd = Upsample(channels, False, dims) elif down: self.h_upd = Downsample(channels, False, dims) self.x_upd = Downsample(channels, False, dims) else: self.h_upd = self.x_upd = nn.Identity() self.emb_layers = nn.Sequential( nn.SiLU(), linear( emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels, ), ) self.out_layers = nn.Sequential( normalization(self.out_channels), nn.SiLU(), nn.Dropout(p=dropout), zero_module( conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) ), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = conv_nd( dims, channels, self.out_channels, 3, padding=1 ) else: 
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) def forward(self, x, emb): """ Apply the block to a Tensor, conditioned on a timestep embedding. :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings. :return: an [N x C x ...] Tensor of outputs. """ return checkpoint( self._forward, (x, emb), self.parameters(), self.use_checkpoint ) def _forward(self, x, emb): if self.updown: in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] h = in_rest(x) h = self.h_upd(h) x = self.x_upd(x) h = in_conv(h) else: h = self.in_layers(x) emb_out = self.emb_layers(emb).type(h.dtype) while len(emb_out.shape) < len(h.shape): emb_out = emb_out[..., None] if self.use_scale_shift_norm: out_norm, out_rest = self.out_layers[0], self.out_layers[1:] scale, shift = th.chunk(emb_out, 2, dim=1) h = out_norm(h) * (1 + scale) + shift h = out_rest(h) else: h = h + emb_out h = self.out_layers(h) return self.skip_connection(x) + h class AttentionBlock(nn.Module): """ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted to the N-d case. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
""" def __init__( self, channels, num_heads=1, num_head_channels=-1, use_checkpoint=False, use_new_attention_order=False, ): super().__init__() self.channels = channels if num_head_channels == -1: self.num_heads = num_heads else: assert ( channels % num_head_channels == 0 ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" self.num_heads = channels // num_head_channels self.use_checkpoint = use_checkpoint self.norm = normalization(channels) self.qkv = conv_nd(1, channels, channels * 3, 1) if use_new_attention_order: # split qkv before split heads self.attention = QKVAttention(self.num_heads) else: # split heads before split qkv self.attention = QKVAttentionLegacy(self.num_heads) self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) def forward(self, x): return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! #return pt_checkpoint(self._forward, x) # pytorch def _forward(self, x): b, c, *spatial = x.shape x = x.reshape(b, c, -1) qkv = self.qkv(self.norm(x)) h = self.attention(qkv) h = self.proj_out(h) return (x + h).reshape(b, c, *spatial) def count_flops_attn(model, _x, y): """ A counter for the `thop` package to count the operations in an attention operation. Meant to be used like: macs, params = thop.profile( model, inputs=(inputs, timestamps), custom_ops={QKVAttention: QKVAttention.count_flops}, ) """ b, c, *spatial = y[0].shape num_spatial = int(np.prod(spatial)) # We perform two matmuls with the same number of ops. # The first computes the weight matrix, the second computes # the combination of the value vectors. matmul_ops = 2 * b * (num_spatial ** 2) * c model.total_ops += th.DoubleTensor([matmul_ops]) class QKVAttentionLegacy(nn.Module): """ A module which performs QKV attention. 
    Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v)
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        q, k, v = qkv.chunk(3, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts",
            (q * scale).view(bs * self.n_heads, ch, length),
            (k * scale).view(bs * self.n_heads, ch, length),
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
        return a.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
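Example (a minimal sketch, not part of the model): the helper below mirrors how the constructor tracks the cumulative downsample rate `ds` while walking `channel_mult`, to show which encoder levels would receive attention for a given `attention_resolutions`. The name `attention_levels` is hypothetical and exists only for illustration.

```python
def attention_levels(channel_mult, attention_resolutions):
    # Hypothetical helper: returns one tuple per encoder level,
    # (level, downsample_rate, channel_multiplier, has_attention).
    levels = []
    ds = 1  # cumulative downsample rate, 1 means full input resolution
    for level, mult in enumerate(channel_mult):
        levels.append((level, ds, mult, ds in attention_resolutions))
        if level != len(channel_mult) - 1:
            ds *= 2  # every level except the last ends with a 2x downsample
    return levels


# With channel_mult=(1, 2, 4) and attention_resolutions=[4, 2, 1],
# the rates visited are 1, 2 and 4, so every level gets attention.
for row in attention_levels((1, 2, 4), [4, 2, 1]):
    print(row)
```

Since `attention_resolutions` holds downsample rates rather than feature-map sizes, a rate that is never reached (for example 8 with three levels) has no effect.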
""" def __init__( self, image_size, in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, num_classes=None, use_checkpoint=False, use_fp16=False, num_heads=-1, num_head_channels=-1, num_heads_upsample=-1, use_scale_shift_norm=False, resblock_updown=False, use_new_attention_order=False, use_max_self_attn=False, use_spatial_transformer=False, # custom transformer support use_max_spatial_transfomer=False, transformer_depth=1, # custom transformer support context_dim=None, # custom transformer support n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model legacy=True, ): super().__init__() if use_spatial_transformer: assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' if context_dim is not None: assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
from omegaconf.listconfig import ListConfig if type(context_dim) == ListConfig: context_dim = list(context_dim) if num_heads_upsample == -1: num_heads_upsample = num_heads if num_heads == -1: assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' if num_head_channels == -1: assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' self.image_size = image_size # 32 self.in_channels = in_channels # 3 self.model_channels = model_channels # 256 self.out_channels = out_channels # 3 self.num_res_blocks = num_res_blocks # 2 self.attention_resolutions = attention_resolutions # [4,2,1] self.dropout = dropout # 0 self.channel_mult = channel_mult # [1,2,4] self.conv_resample = conv_resample # True self.num_classes = num_classes # None self.use_checkpoint = use_checkpoint # False self.dtype = th.float16 if use_fp16 else th.float32 # float32 self.num_heads = num_heads # -1 self.num_head_channels = num_head_channels # 32 self.num_heads_upsample = num_heads_upsample # -1 self.predict_codebook_ids = n_embed is not None # False time_embed_dim = model_channels * 4 # 256*4 = 1024 self.time_embed = nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) if self.num_classes is not None: self.label_emb = nn.Embedding(num_classes, time_embed_dim) self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( conv_nd(dims, in_channels, model_channels, 3, padding=1) ) ] ) self._feature_size = model_channels # 256 input_block_chans = [model_channels] # [256,] ch = model_channels # 256 ds = 1 max_self_attn_ws = min(self.image_size // 4, 8) for level, mult in enumerate(channel_mult): # [1,2,4] for _ in range(num_res_blocks): layers = [ ResBlock( ch, time_embed_dim, dropout, out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = mult * model_channels if ds in attention_resolutions: if num_head_channels == -1: dim_head = 
ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels # 32 if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels layers.append( AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, ) if not use_max_self_attn and not use_spatial_transformer else MaxAttentionBlock( ch, num_heads, dim_head, window_size=max_self_attn_ws ) if not use_spatial_transformer else SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim ) if not use_max_spatial_transfomer else SpatialTransformerWithMax( ch, num_heads, dim_head, context_dim=context_dim ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: out_ch = ch self.input_blocks.append( TimestepEmbedSequential( ResBlock( ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, ) if resblock_updown else Downsample( ch, conv_resample, dims=dims, out_channels=out_ch ) ) ) ch = out_ch input_block_chans.append(ch) ds *= 2 self._feature_size += ch if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels self.middle_block = TimestepEmbedSequential( ResBlock( ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, ) if not use_max_self_attn and not use_spatial_transformer else MaxAttentionBlock( ch, num_heads, dim_head, window_size=max_self_attn_ws ) if not 
use_spatial_transformer else SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim ) if not use_max_spatial_transfomer else SpatialTransformerWithMax( ch, num_heads, dim_head, context_dim=context_dim ), ResBlock( ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), ) self._feature_size += ch self.output_blocks = nn.ModuleList([]) for level, mult in list(enumerate(channel_mult))[::-1]: # [4,2,1] for i in range(num_res_blocks + 1): ich = input_block_chans.pop() layers = [ ResBlock( ch + ich, time_embed_dim, dropout, out_channels=model_channels * mult, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = model_channels * mult if ds in attention_resolutions: if num_head_channels == -1: dim_head = ch // num_heads else: num_heads = ch // num_head_channels dim_head = num_head_channels if legacy: #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels layers.append( AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads_upsample, num_head_channels=dim_head, use_new_attention_order=use_new_attention_order, ) if not use_max_self_attn and not use_spatial_transformer else MaxAttentionBlock( ch, num_heads, dim_head, window_size=max_self_attn_ws ) if not use_spatial_transformer else SpatialTransformer( ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim ) if not use_max_spatial_transfomer else SpatialTransformerWithMax( ch, num_heads, dim_head, context_dim=context_dim ) ) if level and i == num_res_blocks: out_ch = ch layers.append( ResBlock( ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, up=True, ) if resblock_updown else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) ) ds //= 2 self.output_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch 
self.out = nn.Sequential( normalization(ch), nn.SiLU(), zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), ) if self.predict_codebook_ids: self.id_predictor = nn.Sequential( normalization(ch), conv_nd(dims, model_channels, n_embed, 1), #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits ) def convert_to_fp16(self): """ Convert the torso of the model to float16. """ self.input_blocks.apply(convert_module_to_f16) self.middle_block.apply(convert_module_to_f16) self.output_blocks.apply(convert_module_to_f16) def convert_to_fp32(self): """ Convert the torso of the model to float32. """ self.input_blocks.apply(convert_module_to_f32) self.middle_block.apply(convert_module_to_f32) self.output_blocks.apply(convert_module_to_f32) def forward(self, x, timesteps=None, context=None, y=None,**kwargs): """ Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. :param context: conditioning plugged in via crossattn :param y: an [N] Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs. """ assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) emb = self.time_embed(t_emb) # (8,1024) if self.num_classes is not None: assert y.shape == (x.shape[0],) emb = emb + self.label_emb(y) h = x.type(self.dtype) for module in self.input_blocks: h = module(h, emb, context) hs.append(h) h = self.middle_block(h, emb, context) for module in self.output_blocks: h = th.cat([h, hs.pop()], dim=1) h = module(h, emb, context) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) else: return self.out(h) class EncoderUNetModel(nn.Module): """ The half UNet model with attention and timestep embedding. For usage, see UNet. 
""" def __init__( self, image_size, in_channels, model_channels, out_channels, num_res_blocks, attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True, dims=2, use_checkpoint=False, use_fp16=False, num_heads=1, num_head_channels=-1, num_heads_upsample=-1, use_scale_shift_norm=False, resblock_updown=False, use_new_attention_order=False, pool="adaptive", *args, **kwargs ): super().__init__() if num_heads_upsample == -1: num_heads_upsample = num_heads self.in_channels = in_channels self.model_channels = model_channels self.out_channels = out_channels self.num_res_blocks = num_res_blocks self.attention_resolutions = attention_resolutions self.dropout = dropout self.channel_mult = channel_mult self.conv_resample = conv_resample self.use_checkpoint = use_checkpoint self.dtype = th.float16 if use_fp16 else th.float32 self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), ) self.input_blocks = nn.ModuleList( [ TimestepEmbedSequential( conv_nd(dims, in_channels, model_channels, 3, padding=1) ) ] ) self._feature_size = model_channels input_block_chans = [model_channels] ch = model_channels ds = 1 for level, mult in enumerate(channel_mult): for _ in range(num_res_blocks): layers = [ ResBlock( ch, time_embed_dim, dropout, out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ) ] ch = mult * model_channels if ds in attention_resolutions: layers.append( AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=num_head_channels, use_new_attention_order=use_new_attention_order, ) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) if level != len(channel_mult) - 1: out_ch = ch 
self.input_blocks.append( TimestepEmbedSequential( ResBlock( ch, time_embed_dim, dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, down=True, ) if resblock_updown else Downsample( ch, conv_resample, dims=dims, out_channels=out_ch ) ) ) ch = out_ch input_block_chans.append(ch) ds *= 2 self._feature_size += ch self.middle_block = TimestepEmbedSequential( ResBlock( ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), AttentionBlock( ch, use_checkpoint=use_checkpoint, num_heads=num_heads, num_head_channels=num_head_channels, use_new_attention_order=use_new_attention_order, ), ResBlock( ch, time_embed_dim, dropout, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), ) self._feature_size += ch self.pool = pool if pool == "adaptive": self.out = nn.Sequential( normalization(ch), nn.SiLU(), nn.AdaptiveAvgPool2d((1, 1)), zero_module(conv_nd(dims, ch, out_channels, 1)), nn.Flatten(), ) elif pool == "attention": assert num_head_channels != -1 self.out = nn.Sequential( normalization(ch), nn.SiLU(), AttentionPool2d( (image_size // ds), ch, num_head_channels, out_channels ), ) elif pool == "spatial": self.out = nn.Sequential( nn.Linear(self._feature_size, 2048), nn.ReLU(), nn.Linear(2048, self.out_channels), ) elif pool == "spatial_v2": self.out = nn.Sequential( nn.Linear(self._feature_size, 2048), normalization(2048), nn.SiLU(), nn.Linear(2048, self.out_channels), ) else: raise NotImplementedError(f"Unexpected {pool} pooling") def convert_to_fp16(self): """ Convert the torso of the model to float16. """ self.input_blocks.apply(convert_module_to_f16) self.middle_block.apply(convert_module_to_f16) def convert_to_fp32(self): """ Convert the torso of the model to float32. 
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)

    def forward(self, x, timesteps):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x K] Tensor of outputs.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        results = []
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            if self.pool.startswith("spatial"):
                results.append(h.type(x.dtype).mean(dim=(2, 3)))
        h = self.middle_block(h, emb)
        if self.pool.startswith("spatial"):
            results.append(h.type(x.dtype).mean(dim=(2, 3)))
            h = th.cat(results, axis=-1)
            return self.out(h)
        else:
            h = h.type(x.dtype)
            return self.out(h)


================================================
FILE: ldm/modules/diffusionmodules/util.py
================================================
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!
import os
import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat

from ldm.util import instantiate_from_config


def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if schedule == "linear":
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # betas is a torch tensor here, so clamp with torch; this keeps the
        # final `.numpy()` call valid regardless of how np.clip treats tensors.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    if ddim_discr_method == 'uniform':
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoidal encoding and repeat each
                        raw timestep `dim` times instead.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class HybridConditioner(nn.Module):

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}


def noise_like(shape, device, repeat=False):
    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
    noise = lambda: torch.randn(shape, device=device)
    return repeat_noise() if repeat else noise()


================================================
FILE: ldm/modules/ema.py
================================================
import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
""" for c_param, param in zip(self.collected_params, parameters): param.data.copy_(c_param.data) ================================================ FILE: ldm/modules/losses/__init__.py ================================================ ================================================ FILE: ldm/modules/losses/vqperceptual.py ================================================ import torch from torch import nn import torch.nn.functional as F from einops import repeat from taming.modules.discriminator.model import NLayerDiscriminator, weights_init from taming.modules.losses.lpips import LPIPS from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) loss_real = (weights * loss_real).sum() / weights.sum() loss_fake = (weights * loss_fake).sum() / weights.sum() d_loss = 0.5 * (loss_real + loss_fake) return d_loss def adopt_weight(weight, global_step, threshold=0, value=0.): if global_step < threshold: weight = value return weight def measure_perplexity(predicted_indices, n_embed): # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) avg_probs = encodings.mean(0) perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() cluster_use = torch.sum(avg_probs > 0) return perplexity, cluster_use def l1(x, y): return torch.abs(x-y) def l2(x, y): return torch.pow((x-y), 2) class VQLPIPSWithDiscriminator(nn.Module): def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", pixel_loss="l1"): super().__init__() assert disc_loss in ["hinge", "vanilla"] assert perceptual_loss in ["lpips", "clips", "dists"] assert pixel_loss in ["l1", "l2"] self.codebook_weight = codebook_weight self.pixel_weight = pixelloss_weight if perceptual_loss == "lpips": print(f"{self.__class__.__name__}: Running with LPIPS.") self.perceptual_loss = LPIPS().eval() else: raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") self.perceptual_weight = perceptual_weight if pixel_loss == "l1": self.pixel_loss = l1 else: self.pixel_loss = l2 self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm, ndf=disc_ndf ).apply(weights_init) self.discriminator_iter_start = disc_start if disc_loss == "hinge": self.disc_loss = hinge_d_loss elif disc_loss == "vanilla": self.disc_loss = vanilla_d_loss else: raise ValueError(f"Unknown GAN loss '{disc_loss}'.") print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") self.disc_factor = disc_factor self.discriminator_weight = disc_weight self.disc_conditional = disc_conditional self.n_classes = n_classes def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): if last_layer is not None: nll_grads = 
torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] else: nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() d_weight = d_weight * self.discriminator_weight return d_weight def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, global_step, last_layer=None, cond=None, split="train", predicted_indices=None): #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) if self.perceptual_weight > 0: p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) rec_loss = rec_loss + self.perceptual_weight * p_loss else: p_loss = torch.tensor([0.0]) nll_loss = rec_loss #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] nll_loss = torch.mean(nll_loss) # now the GAN part if optimizer_idx == 0: # generator update if cond is None: assert not self.disc_conditional logits_fake = self.discriminator(reconstructions.contiguous()) else: assert self.disc_conditional logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) g_loss = -torch.mean(logits_fake) try: d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) except RuntimeError: assert not self.training d_weight = torch.tensor(0.0) disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/quant_loss".format(split): codebook_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 
"{}/rec_loss".format(split): rec_loss.detach().mean(), "{}/p_loss".format(split): p_loss.detach().mean(), "{}/d_weight".format(split): d_weight.detach(), "{}/disc_factor".format(split): torch.tensor(disc_factor), "{}/g_loss".format(split): g_loss.detach().mean(), } if predicted_indices is not None: assert self.n_classes is not None with torch.no_grad(): perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) log[f"{split}/perplexity"] = perplexity log[f"{split}/cluster_usage"] = cluster_usage return loss, log if optimizer_idx == 1: # second pass for discriminator update if cond is None: logits_real = self.discriminator(inputs.contiguous().detach()) logits_fake = self.discriminator(reconstructions.contiguous().detach()) else: logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), "{}/logits_real".format(split): logits_real.detach().mean(), "{}/logits_fake".format(split): logits_fake.detach().mean() } return d_loss, log ================================================ FILE: ldm/modules/maxvit.py ================================================ import torch from torch import nn, einsum import torch.nn.functional from einops import rearrange, repeat from einops.layers.torch import Rearrange, Reduce from inspect import isfunction # Code adapted from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/max_vit.py def exists(val): return val is not None def default(val, d): if exists(val): return val return d() if isfunction(d) else d class PreNormResidual(nn.Module): def __init__(self, dim, fn): super().__init__() self.norm = nn.LayerNorm(dim) self.fn = fn def forward(self, x, 
c=None): if exists(c): return self.fn(self.norm(x), self.norm(c)) + x return self.fn(self.norm(x)) + x class SqueezeExcitation(nn.Module): def __init__(self, dim, shrinkage_rate = 0.25): super().__init__() hidden_dim = int(dim * shrinkage_rate) self.gate = nn.Sequential( Reduce('b c h w -> b c', 'mean'), nn.Linear(dim, hidden_dim, bias = False), nn.SiLU(), nn.Linear(hidden_dim, dim, bias = False), nn.Sigmoid(), Rearrange('b c -> b c 1 1') ) def forward(self, x): return x * self.gate(x) class FeedForward(nn.Module): def __init__(self, dim, mult = 4, dropout = 0.): super().__init__() inner_dim = int(dim * mult) self.net = nn.Sequential( nn.Linear(dim, inner_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(inner_dim, dim), nn.Dropout(dropout) ) def forward(self, x): return self.net(x) class Attention(nn.Module): def __init__( self, dim, dim_head = 32, dropout = 0., window_size = 7 ): super().__init__() assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head' self.heads = dim // dim_head self.scale = dim_head ** -0.5 self.to_q = nn.Linear(dim, dim, bias = False) self.to_k = nn.Linear(dim, dim, bias = False) self.to_v = nn.Linear(dim, dim, bias = False) self.attend = nn.Sequential( nn.Softmax(dim = -1), nn.Dropout(dropout) ) self.to_out = nn.Sequential( nn.Linear(dim, dim, bias = False), nn.Dropout(dropout) ) # relative positional bias self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads) pos = torch.arange(window_size) grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij')) grid = rearrange(grid, 'c i j -> (i j) c') rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... 
-> 1 j ...') rel_pos += window_size - 1 rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1) self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False) def forward(self, x, c=None): c = default(c, x) batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads # flatten x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d') c = rearrange(c, 'b x y w1 w2 d -> (b x y) (w1 w2) d') # project for queries, keys, values q = self.to_q(x) k = self.to_k(c) v = self.to_v(c) # split heads q, k, v = map(lambda t: rearrange(t, 'b n (h d ) -> b h n d', h = h), (q, k, v)) # scale q = q * self.scale # sim sim = einsum('b h i d, b h j d -> b h i j', q, k) # add positional bias bias = self.rel_pos_bias(self.rel_pos_indices) sim = sim + rearrange(bias, 'i j h -> h i j') # attention attn = self.attend(sim) # aggregate out = einsum('b h i j, b h j d -> b h i d', attn, v) # merge heads out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width) # combine heads out out = self.to_out(out) return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width) class Dropsample(nn.Module): def __init__(self, prob = 0): super().__init__() self.prob = prob def forward(self, x): device = x.device if self.prob == 0. or (not self.training): return x keep_mask = torch.rand((x.shape[0], 1, 1, 1), device = device) > self.prob return x * keep_mask / (1 - self.prob) class MBConvResidual(nn.Module): def __init__(self, fn, dropout = 0.): super().__init__() self.fn = fn self.dropsample = Dropsample(dropout) def forward(self, x): out = self.fn(x) out = self.dropsample(out) return out + x def MBConv( dim_in, dim_out, *, downsample, expansion_rate = 4, shrinkage_rate = 0.25, dropout = 0. 
): hidden_dim = int(expansion_rate * dim_out) stride = 2 if downsample else 1 net = nn.Sequential( nn.Conv2d(dim_in, hidden_dim, 1), nn.BatchNorm2d(hidden_dim), nn.GELU(), nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim), nn.BatchNorm2d(hidden_dim), nn.GELU(), SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate), nn.Conv2d(hidden_dim, dim_out, 1), nn.BatchNorm2d(dim_out) ) if dim_in == dim_out and not downsample: net = MBConvResidual(net, dropout = dropout) return net class MaxAttentionBlock(nn.Module): def __init__(self, in_channels, heads=8, dim_head=64, dropout=0., window_size=8): super().__init__() w = window_size layer_dim = dim_head * heads self.rearrange_block_in = Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w) # block-like attention self.attn_block = PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)) self.ff_block = PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)) self.rearrange_block_out = Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)') self.rearrange_grid_in = Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w) # grid-like attention self.attn_grid = PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)) self.ff_grid = PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)) self.rearrange_grid_out = Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)') def forward(self, x): # block attention x = self.rearrange_block_in(x) x = self.attn_block(x) x = self.ff_block(x) x = self.rearrange_block_out(x) # grid attention x = self.rearrange_grid_in(x) x = self.attn_grid(x) x = self.ff_grid(x) x = self.rearrange_grid_out(x) ## output stage return x class SpatialCrossAttentionWithMax(nn.Module): def __init__(self, in_channels, heads=8, dim_head=64, ctx_dim=None, dropout=0., window_size=8): super().__init__() w = window_size layer_dim 
= dim_head * heads if ctx_dim is None: self.proj_in = MBConv(layer_dim*2, layer_dim, downsample=False) else: self.proj_in = MBConv(ctx_dim, layer_dim, downsample=False) self.rearrange_block_in = Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w) # block-like attention self.attn_block = PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)) self.ff_block = PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)) self.rearrange_block_out = Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)') self.rearrange_grid_in = Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w) # grid-like attention self.attn_grid = PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)) self.ff_grid = PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)) self.rearrange_grid_out = Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)') self.out_conv = nn.Sequential( SqueezeExcitation(dim=layer_dim*2), nn.Conv2d(layer_dim*2, layer_dim, kernel_size=3, padding=1) ) def forward(self, x, context=None): context = default(context, x) # MBConv c = self.proj_in(context) # block attention x = self.rearrange_block_in(x) c = self.rearrange_block_in(c) x = self.attn_block(x, c) x = self.ff_block(x) x = self.rearrange_block_out(x) c = self.rearrange_block_out(c) # grid attention x = self.rearrange_grid_in(x) c = self.rearrange_grid_in(c) x = self.attn_grid(x, c) x = self.ff_grid(x) x = self.rearrange_grid_out(x) return x class SpatialTransformerWithMax(nn.Module): """ Cross-attention block for image-like data, built from MaxViT-style window attention. The context is average-pooled to the spatial size of x if needed, then projected to inner_dim (n_heads * d_head) with an MBConv block. Block attention and grid attention are then applied in turn, each followed by a feed-forward layer, with x as the queries and the projected context as the keys and values. x must already have inner_dim channels, and the output has the same shape as x (in_channels is stored but not used to project x). 
""" def __init__(self, in_channels, n_heads, d_head, dropout=0., context_dim=None, w=2): super().__init__() self.in_channels = in_channels self.context_dim = context_dim inner_dim = n_heads * d_head self.proj_in = MBConv(context_dim, inner_dim, downsample=False) self.rearrange_block_in = Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w) # block-like attention self.attn_block = PreNormResidual(inner_dim, Attention(dim = inner_dim, dim_head = d_head, dropout = dropout, window_size = w)) self.ff_block = PreNormResidual(inner_dim, FeedForward(dim = inner_dim, dropout = dropout)) self.rearrange_block_out = Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)') self.rearrange_grid_in = Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w) # grid-like attention self.attn_grid = PreNormResidual(inner_dim, Attention(dim = inner_dim, dim_head = d_head, dropout = dropout, window_size = w)) self.ff_grid = PreNormResidual(inner_dim, FeedForward(dim = inner_dim, dropout = dropout)) self.rearrange_grid_out = Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)') def forward(self, x, context=None): context = default(context, x) # down sample context if necessary # this is due to the implementation of max crossattn here if context.shape[2] != x.shape[2]: stride = context.shape[2] // x.shape[2] context = torch.nn.functional.avg_pool2d(context, kernel_size=stride, stride=stride) # MBConv c = self.proj_in(context) # block attention x = self.rearrange_block_in(x) c = self.rearrange_block_in(c) x = self.attn_block(x, c) x = self.ff_block(x) x = self.rearrange_block_out(x) c = self.rearrange_block_out(c) # grid attention x = self.rearrange_grid_in(x) c = self.rearrange_grid_in(c) x = self.attn_grid(x, c) x = self.ff_grid(x) x = self.rearrange_grid_out(x) return x ================================================ FILE: ldm/util.py ================================================ import importlib import torch import numpy as np from collections import abc from einops import 
rearrange from functools import partial import multiprocessing as mp from threading import Thread from queue import Queue from inspect import isfunction from PIL import Image, ImageDraw, ImageFont def log_txt_as_img(wh, xc, size=10): # wh: a tuple of (width, height) # xc: a list of captions to plot b = len(xc) txts = list() for bi in range(b): txt = Image.new("RGB", wh, color="white") draw = ImageDraw.Draw(txt) font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) nc = int(40 * (wh[0] / 256)) lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) try: draw.text((0, 0), lines, fill="black", font=font) except UnicodeEncodeError: print("Can't encode string for logging. Skipping.") txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txts.append(txt) txts = np.stack(txts) txts = torch.tensor(txts) return txts def ismap(x): if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] > 3) def isimage(x): if not isinstance(x, torch.Tensor): return False return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) def exists(x): return x is not None def default(val, d): if exists(val): return val return d() if isfunction(d) else d def mean_flat(tensor): """ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 Take the mean over all non-batch dimensions. 
""" return tensor.mean(dim=list(range(1, len(tensor.shape)))) def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") return total_params def instantiate_from_config(config): if not "target" in config: if config == '__is_first_stage__': return None elif config == "__is_unconditional__": return None raise KeyError("Expected key `target` to instantiate.") return get_obj_from_str(config["target"])(**config.get("params", dict())) def get_obj_from_str(string, reload=False): module, cls = string.rsplit(".", 1) if reload: module_imp = importlib.import_module(module) importlib.reload(module_imp) return getattr(importlib.import_module(module, package=None), cls) def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): # create dummy dataset instance # run prefetching if idx_to_fn: res = func(data, worker_id=idx) else: res = func(data) Q.put([idx, res]) Q.put("Done") def parallel_data_prefetch( func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False ): # if target_data_type not in ["ndarray", "list"]: # raise ValueError( # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." # ) if isinstance(data, np.ndarray) and target_data_type == "list": raise ValueError("list expected but function got ndarray.") elif isinstance(data, abc.Iterable): if isinstance(data, dict): print( f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' ) data = list(data.values()) if target_data_type == "ndarray": data = np.asarray(data) else: data = list(data) else: raise TypeError( f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
) if cpu_intensive: Q = mp.Queue(1000) proc = mp.Process else: Q = Queue(1000) proc = Thread # spawn processes if target_data_type == "ndarray": arguments = [ [func, Q, part, i, use_worker_id] for i, part in enumerate(np.array_split(data, n_proc)) ] else: step = ( int(len(data) / n_proc + 1) if len(data) % n_proc != 0 else int(len(data) / n_proc) ) arguments = [ [func, Q, part, i, use_worker_id] for i, part in enumerate( [data[i: i + step] for i in range(0, len(data), step)] ) ] processes = [] for i in range(n_proc): p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) processes += [p] # start processes print(f"Start prefetching...") import time start = time.time() gather_res = [[] for _ in range(n_proc)] try: for p in processes: p.start() k = 0 while k < n_proc: # get result res = Q.get() if res == "Done": k += 1 else: gather_res[res[0]] = res[1] except Exception as e: print("Exception: ", e) for p in processes: p.terminate() raise e finally: for p in processes: p.join() print(f"Prefetching complete. 
[{time.time() - start} sec.]") if target_data_type == 'ndarray': if not isinstance(gather_res[0], np.ndarray): return np.concatenate([np.asarray(r) for r in gather_res], axis=0) # order outputs return np.concatenate(gather_res, axis=0) elif target_data_type == 'list': out = [] for r in gather_res: out.extend(r) return out else: return gather_res ================================================ FILE: main.py ================================================ import argparse, os, sys, datetime, glob import numpy as np import time import torch import torchvision import pytorch_lightning as pl from packaging import version from omegaconf import OmegaConf from torch.utils.data import DataLoader, Dataset from functools import partial from PIL import Image from pytorch_lightning import seed_everything from pytorch_lightning.trainer import Trainer from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities import rank_zero_info from ldm.util import instantiate_from_config def get_parser(**parser_kwargs): def str2bool(v): if isinstance(v, bool): return v if v.lower() in ("yes", "true", "t", "y", "1"): return True elif v.lower() in ("no", "false", "f", "n", "0"): return False else: raise argparse.ArgumentTypeError("Boolean value expected.") parser = argparse.ArgumentParser(**parser_kwargs) parser.add_argument( "-n", "--name", type=str, const=True, default="", nargs="?", help="postfix for logdir", ) parser.add_argument( "-r", "--resume", type=str, const=True, default="", nargs="?", help="resume from checkpoint", ) parser.add_argument( "-b", "--base", nargs="*", metavar="base_config.yaml", help="paths to base configs. Loaded from left-to-right. 
" "Parameters can be overwritten or added with command-line options of the form `--key value`.", default=list(), ) parser.add_argument( "-t", "--train", type=str2bool, const=True, default=False, nargs="?", help="train", ) parser.add_argument( "--no-test", type=str2bool, const=True, default=False, nargs="?", help="disable test", ) parser.add_argument( "-p", "--project", help="name of new or path to existing project" ) parser.add_argument( "-d", "--debug", type=str2bool, nargs="?", const=True, default=False, help="enable post-mortem debugging", ) parser.add_argument( "-s", "--seed", type=int, default=23, help="seed for seed_everything", ) parser.add_argument( "-f", "--postfix", type=str, default="", help="post-postfix for default name", ) parser.add_argument( "-l", "--logdir", type=str, default="logs", help="directory for logging dat shit", ) parser.add_argument( "--scale_lr", type=str2bool, nargs="?", const=True, default=True, help="scale base-lr by ngpu * batch_size * n_accumulate", ) return parser def nondefault_trainer_args(opt): parser = argparse.ArgumentParser() parser = Trainer.add_argparse_args(parser) args = parser.parse_args([]) return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) class WrappedDataset(Dataset): """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" def __init__(self, dataset): self.data = dataset def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] def worker_init_fn(_): worker_info = torch.utils.data.get_worker_info() dataset = worker_info.dataset worker_id = worker_info.id return np.random.seed(np.random.get_state()[1][0] + worker_id) class DataModuleFromConfig(pl.LightningDataModule): def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False, shuffle_val_dataloader=False): super().__init__() self.batch_size = batch_size self.dataset_configs = 
dict() self.num_workers = num_workers if num_workers is not None else batch_size * 2 self.use_worker_init_fn = use_worker_init_fn if train is not None: self.dataset_configs["train"] = train self.train_dataloader = self._train_dataloader if validation is not None: self.dataset_configs["validation"] = validation self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) if test is not None: self.dataset_configs["test"] = test self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) if predict is not None: self.dataset_configs["predict"] = predict self.predict_dataloader = self._predict_dataloader self.wrap = wrap def prepare_data(self): for data_cfg in self.dataset_configs.values(): instantiate_from_config(data_cfg) def setup(self, stage=None): self.datasets = dict( (k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) if self.wrap: for k in self.datasets: self.datasets[k] = WrappedDataset(self.datasets[k]) def _train_dataloader(self): if self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None return DataLoader(self.datasets["train"], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True, worker_init_fn=init_fn) def _val_dataloader(self, shuffle=False): if self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None return DataLoader(self.datasets["validation"], batch_size=self.batch_size, num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle) def _test_dataloader(self, shuffle=False): if self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None return DataLoader(self.datasets["test"], batch_size=self.batch_size, num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle) def _predict_dataloader(self, shuffle=False): if self.use_worker_init_fn: init_fn = worker_init_fn else: init_fn = None return DataLoader(self.datasets["predict"], batch_size=self.batch_size, num_workers=self.num_workers, 
worker_init_fn=init_fn) class SetupCallback(Callback): def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config): super().__init__() self.resume = resume self.now = now self.logdir = logdir self.ckptdir = ckptdir self.cfgdir = cfgdir self.config = config self.lightning_config = lightning_config def on_keyboard_interrupt(self, trainer, pl_module): if trainer.global_rank == 0: print("Summoning checkpoint.") ckpt_path = os.path.join(self.ckptdir, "last.ckpt") trainer.save_checkpoint(ckpt_path) def on_fit_start(self, trainer, pl_module): if trainer.global_rank == 0: # Create logdirs and save configs os.makedirs(self.logdir, exist_ok=True) os.makedirs(self.ckptdir, exist_ok=True) os.makedirs(self.cfgdir, exist_ok=True) if "callbacks" in self.lightning_config: if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']: os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True) print("Project config") print(OmegaConf.to_yaml(self.config)) OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))) print("Lightning config") print(OmegaConf.to_yaml(self.lightning_config)) OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))) else: # ModelCheckpoint callback created log directory --- remove it if not self.resume and os.path.exists(self.logdir): dst, name = os.path.split(self.logdir) dst = os.path.join(dst, "child_runs", name) os.makedirs(os.path.split(dst)[0], exist_ok=True) try: os.rename(self.logdir, dst) except FileNotFoundError: pass class ImageLogger(Callback): def __init__(self, batch_frequency, val_batch_frequency, max_images, clamp=True, increase_log_steps=True, rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False, log_images_kwargs=None): super().__init__() self.rescale = rescale self.batch_freq = batch_frequency self.val_batch_frequency = val_batch_frequency 
self.max_images = max_images self.logger_log_images = { pl.loggers.TensorBoardLogger: self._testtube, } self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)] if not increase_log_steps: self.log_steps = [self.batch_freq] self.clamp = clamp self.disabled = disabled self.log_on_batch_idx = log_on_batch_idx self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} self.log_first_step = log_first_step self.val_psnr_epoch = [] @rank_zero_only def _testtube(self, pl_module, images, batch_idx, split): for k in images: grid = torchvision.utils.make_grid(images[k]) grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w tag = f"{split}/{k}" pl_module.logger.experiment.add_image( tag, grid, global_step=pl_module.global_step) @rank_zero_only def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): root = os.path.join(save_dir, "images", split) for k in images: grid = torchvision.utils.make_grid(images[k], nrow=4) if self.rescale: grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) grid = grid.numpy() grid = (grid * 255).astype(np.uint8) filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format( k, global_step, current_epoch, batch_idx) path = os.path.join(root, filename) os.makedirs(os.path.split(path)[0], exist_ok=True) Image.fromarray(grid).save(path) def log_img(self, pl_module, batch, batch_idx, split="train"): log = False check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step if (split == 'train' and self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0 hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0): log = True elif (split == 'val' and (batch_idx % self.val_batch_frequency) == 0 and hasattr(pl_module, "log_images") and callable(pl_module.log_images) and self.max_images > 0): log = True if log: logger = type(pl_module.logger) is_train = pl_module.training if is_train: pl_module.eval() with 
torch.no_grad(): images = pl_module.log_images(batch, split=split, **self.log_images_kwargs) for k in images: N = min(images[k].shape[0], self.max_images) images[k] = images[k][:N] if isinstance(images[k], torch.Tensor): images[k] = images[k].detach().cpu() if self.clamp: images[k] = torch.clamp(images[k], -1., 1.) # calculate PSNR using images['samples'] if split == 'val': out = images['samples'] if 'samples' in images.keys() else images['reconstructions'] samples = ((out+1.0)/2.0).mul(255).round() gts = ((batch['image'][:N]+1.0)/2.0).mul(255).round().permute((0,3,1,2)).cpu() mse = torch.mean((samples - gts)**2, dim=1).mean(1).mean(1) psnr = -10 * torch.log10(mse/255**2 + 1e-8) self.val_psnr_epoch.append(psnr) self.log_local(pl_module.logger.save_dir, split, images, pl_module.global_step, pl_module.current_epoch, batch_idx) logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None) logger_log_images(pl_module, images, pl_module.global_step, split) if is_train: pl_module.train() def check_frequency(self, check_idx): if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and ( check_idx > 0 or self.log_first_step): try: self.log_steps.pop(0) except IndexError as e: print(e) pass return True return False def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): self.log_img(pl_module, batch, batch_idx, split="train") def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0): if not self.disabled and pl_module.global_step > 0: self.log_img(pl_module, batch, batch_idx, split="val") if hasattr(pl_module, 'calibrate_grad_norm'): if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0: self.log_gradients(trainer, pl_module, batch_idx=batch_idx) def on_validation_epoch_end(self, trainer, pl_module): if len(self.val_psnr_epoch) > 0: epoch_psnr = 
torch.cat(self.val_psnr_epoch).mean().item() pl_module.log_dict({'val/psnr':epoch_psnr}, prog_bar=False, logger=True, on_step=False, on_epoch=True) self.val_psnr_epoch = [] class CUDACallback(Callback): # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py def on_train_epoch_start(self, trainer, pl_module): # Reset the memory use counter torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index) torch.cuda.synchronize(trainer.strategy.root_device.index) self.start_time = time.time() def on_train_epoch_end(self, trainer, pl_module, outputs=None): torch.cuda.synchronize(trainer.strategy.root_device.index) max_memory = torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) / 2 ** 20 epoch_time = time.time() - self.start_time try: max_memory = trainer.strategy.reduce(max_memory) epoch_time = trainer.strategy.reduce(epoch_time) rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") except AttributeError: pass if __name__ == "__main__": now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # add cwd for convenience and to make classes in this file available when # running as `python main.py` # (in particular `main.DataModuleFromConfig`) sys.path.append(os.getcwd()) parser = get_parser() parser = Trainer.add_argparse_args(parser) opt, unknown = parser.parse_known_args() if opt.name and opt.resume: raise ValueError( "-n/--name and -r/--resume cannot be specified both." 
"If you want to resume training in a new log folder, " "use -n/--name in combination with --resume_from_checkpoint" ) if opt.resume: if not os.path.exists(opt.resume): raise ValueError("Cannot find {}".format(opt.resume)) if os.path.isfile(opt.resume): paths = opt.resume.split("/") # idx = len(paths)-paths[::-1].index("logs")+1 # logdir = "/".join(paths[:idx]) logdir = "/".join(paths[:-2]) ckpt = opt.resume else: assert os.path.isdir(opt.resume), opt.resume logdir = opt.resume.rstrip("/") ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") opt.resume_from_checkpoint = ckpt base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml"))) opt.base = base_configs + opt.base _tmp = logdir.split("/") nowname = _tmp[-1] else: if opt.name: name = "_" + opt.name elif opt.base: cfg_fname = os.path.split(opt.base[0])[-1] cfg_name = os.path.splitext(cfg_fname)[0] name = "_" + cfg_name else: name = "" nowname = now + name + opt.postfix logdir = os.path.join(opt.logdir, nowname) ckptdir = os.path.join(logdir, "checkpoints") cfgdir = os.path.join(logdir, "configs") seed_everything(opt.seed) try: # init and save configs configs = [OmegaConf.load(cfg) for cfg in opt.base] cli = OmegaConf.from_dotlist(unknown) config = OmegaConf.merge(*configs, cli) lightning_config = config.pop("lightning", OmegaConf.create()) # merge trainer cli with config trainer_config = lightning_config.get("trainer", OmegaConf.create()) # default to ddp # trainer_config["accelerator"] = "ddp" # TODO: confirm why defaulting to ddp doesn't work trainer_config["accelerator"] = "gpu" for k in nondefault_trainer_args(opt): trainer_config[k] = getattr(opt, k) if not "gpus" in trainer_config: del trainer_config["accelerator"] cpu = True else: gpuinfo = trainer_config["gpus"] print(f"Running on GPUs {gpuinfo}") cpu = False trainer_config['devices'] = trainer_config.pop('gpus') trainer_opt = argparse.Namespace(**trainer_config) lightning_config.trainer = trainer_config # model model = 
instantiate_from_config(config.model) # trainer and callbacks trainer_kwargs = dict() # default logger configs default_logger_cfgs = { "wandb": { "target": "pytorch_lightning.loggers.WandbLogger", "params": { "name": nowname, "save_dir": logdir, "offline": opt.debug, "id": nowname, } }, "testtube": { "target": "pytorch_lightning.loggers.TensorBoardLogger", "params": { "name": "testtube", "save_dir": logdir, } }, } default_logger_cfg = default_logger_cfgs["testtube"] if "logger" in lightning_config: logger_cfg = lightning_config.logger else: logger_cfg = OmegaConf.create() logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to # specify which metric is used to determine best models default_modelckpt_cfg = { "target": "pytorch_lightning.callbacks.ModelCheckpoint", "params": { "dirpath": ckptdir, "filename": "{epoch:06}", "verbose": True, "save_last": True, } } if hasattr(model, "monitor"): print(f"Monitoring {model.monitor} as checkpoint metric.") default_modelckpt_cfg["params"]["monitor"] = model.monitor default_modelckpt_cfg["params"]["save_top_k"] = 3 if "modelcheckpoint" in lightning_config: modelckpt_cfg = lightning_config.modelcheckpoint else: modelckpt_cfg = OmegaConf.create() modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg) print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}") if version.parse(pl.__version__) < version.parse('1.4.0'): trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg) # add callback which sets up log directory default_callbacks_cfg = { "setup_callback": { "target": "main.SetupCallback", "params": { "resume": opt.resume, "now": now, "logdir": logdir, "ckptdir": ckptdir, "cfgdir": cfgdir, "config": config, "lightning_config": lightning_config, } }, "image_logger": { "target": "main.ImageLogger", "params": { "batch_frequency": 750, "max_images": 4, "clamp": True 
} }, "learning_rate_logger": { "target": "main.LearningRateMonitor", "params": { "logging_interval": "step", # "log_momentum": True } }, "cuda_callback": { "target": "main.CUDACallback" }, "progress_bar": { "target": 'pytorch_lightning.callbacks.TQDMProgressBar', "params": { 'refresh_rate': 100, } }, } if version.parse(pl.__version__) >= version.parse('1.4.0'): default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg}) if "callbacks" in lightning_config: callbacks_cfg = lightning_config.callbacks else: callbacks_cfg = OmegaConf.create() if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg: print( 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.') default_metrics_over_trainsteps_ckpt_dict = { 'metrics_over_trainsteps_checkpoint': {"target": 'pytorch_lightning.callbacks.ModelCheckpoint', 'params': { "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'), "filename": "{epoch:06}-{step:09}", "verbose": True, 'save_top_k': -1, 'every_n_train_steps': 10000, 'save_weights_only': True } } } default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'): callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint elif 'ignore_keys_callback' in callbacks_cfg: del callbacks_cfg['ignore_keys_callback'] trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) trainer.logdir = logdir ### # data data = instantiate_from_config(config.data) # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html # calling these ourselves should not be necessary but it is. 
# lightning still takes care of proper multiprocessing though data.prepare_data() data.setup() print("#### Data #####") for k in data.datasets: print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}") # configure learning rate bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate if not cpu: ngpu = len(lightning_config.trainer.devices.strip(",").split(',')) else: ngpu = 1 if 'accumulate_grad_batches' in lightning_config.trainer: accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches else: accumulate_grad_batches = 1 print(f"accumulate_grad_batches = {accumulate_grad_batches}") lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches if opt.scale_lr: model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr print( "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) else: model.learning_rate = base_lr print("++++ NOT USING LR SCALING ++++") print(f"Setting learning rate to {model.learning_rate:.2e}") # allow checkpointing via USR1 def melk(*args, **kwargs): # run all checkpoint hooks if trainer.global_rank == 0: print("Summoning checkpoint.") ckpt_path = os.path.join(ckptdir, "last.ckpt") trainer.save_checkpoint(ckpt_path) def divein(*args, **kwargs): if trainer.global_rank == 0: import pudb; pudb.set_trace() import signal import platform # see https://github.com/rinongal/textual_inversion/issues/44 if platform.system() == 'Windows': os.environ["PL_TORCH_DISTRIBUTED_BACKEND"] = "gloo" signal.signal(signal.SIGTERM, melk) signal.signal(signal.SIGTERM, divein) else: signal.signal(signal.SIGUSR1, melk) signal.signal(signal.SIGUSR2, divein) # run if opt.train: try: trainer.fit(model, data) except Exception: melk() raise if not opt.no_test and not trainer.interrupted: trainer.test(model, data) except Exception: if opt.debug and 
trainer.global_rank == 0: try: import pudb as debugger except ImportError: import pdb as debugger debugger.post_mortem() raise finally: # move newly created debug project to debug_runs if opt.debug and not opt.resume and trainer.global_rank == 0: dst, name = os.path.split(logdir) dst = os.path.join(dst, "debug_runs", name) os.makedirs(os.path.split(dst)[0], exist_ok=True) os.rename(logdir, dst) if trainer.global_rank == 0: print(trainer.profiler.summary()) ================================================ FILE: metrics/flolpips/.gitignore ================================================ *__pycache__* *.ipynb *delete* ================================================ FILE: metrics/flolpips/LICENSE ================================================ MIT License Copyright (c) 2022 danielism97 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================
FILE: metrics/flolpips/README.md
================================================
# FloLPIPS: A bespoke video quality metric for frame interpolation

### Duolikun Danier, Fan Zhang, David Bull

[Project](https://danielism97.github.io/FloLPIPS) | [arXiv](https://arxiv.org/abs/2207.08119)

## Dependencies

The following packages were used to evaluate the model.

- python==3.8.8
- pytorch==1.7.1
- torchvision==0.8.2
- cudatoolkit==10.1.243
- opencv-python==4.5.1.48
- numpy==1.19.2
- pillow==8.1.2
- cupy==9.0.0

## Usage

```python
from flolpips import calc_flolpips
ref_video = '.mp4'
dis_video = '.mp4'
res = calc_flolpips(dis_video, ref_video)
```

## Citation

```
@article{danier2022flolpips,
  title={FloLPIPS: A Bespoke Video Quality Metric for Frame Interpolation},
  author={Danier, Duolikun and Zhang, Fan and Bull, David},
  journal={arXiv preprint arXiv:2207.08119},
  year={2022}
}
```

## Acknowledgement

Much of the code in this repository is adapted or taken from the following repositories:

- [LPIPS](https://github.com/richzhang/PerceptualSimilarity)
- [pytorch-pwc](https://github.com/sniklaus/pytorch-pwc)

We would like to thank the authors for sharing their code.
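
The score returned by `calc_flolpips` pools per-layer feature differences with weights proportional to optical-flow magnitude (see `mw_spatial_average` in `flolpips.py`). The snippet below is a minimal, dependency-free sketch of that weighting for a single 2-D map, not the repository's implementation. The name `motion_weighted_average`, the list-of-rows inputs, and the fallback to a plain mean when the flow is zero everywhere are illustrative assumptions. The flow is assumed to have already been resized to the resolution of the difference map.

```python
import math


def motion_weighted_average(diff_map, flow_u, flow_v):
    """Pool a 2-D difference map using normalised flow magnitude as weights.

    diff_map, flow_u and flow_v are equally sized lists of rows (H x W).
    Hypothetical helper for illustration only.
    """
    # Per-pixel flow magnitude sqrt(u^2 + v^2).
    mags = [
        [math.hypot(u, v) for u, v in zip(row_u, row_v)]
        for row_u, row_v in zip(flow_u, flow_v)
    ]
    total = sum(sum(row) for row in mags)

    if total == 0.0:
        # Assumption (not in the original code): with no motion anywhere,
        # fall back to an unweighted spatial mean instead of dividing by zero.
        count = sum(len(row) for row in diff_map)
        return sum(sum(row) for row in diff_map) / count

    # Weighted sum of differences, with weights normalised to sum to one.
    weighted = sum(
        d * m
        for row_d, row_m in zip(diff_map, mags)
        for d, m in zip(row_d, row_m)
    )
    return weighted / total


# Toy example: the only non-zero difference sits on the only moving pixel.
diff = [[0.0, 1.0], [0.0, 0.0]]
u = [[0.0, 3.0], [0.0, 0.0]]
v = [[0.0, 4.0], [0.0, 0.0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]

moving_score = motion_weighted_average(diff, u, v)
static_score = motion_weighted_average(diff, zeros, zeros)
print(moving_score, static_score)
```

In the toy example the error coincides with the moving pixel, so the motion-weighted score is larger than the plain spatial mean that the fallback gives for a static scene. This is the intended effect of the weighting: distortions in moving regions count for more.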
================================================ FILE: metrics/flolpips/__init__.py ================================================ from .flolpips import * ================================================ FILE: metrics/flolpips/correlation/correlation.py ================================================ #!/usr/bin/env python import torch import cupy import re kernel_Correlation_rearrange = ''' extern "C" __global__ void kernel_Correlation_rearrange( const int n, const float* input, float* output ) { int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; if (intIndex >= n) { return; } int intSample = blockIdx.z; int intChannel = blockIdx.y; float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; __syncthreads(); int intPaddedY = (intIndex / SIZE_3(input)) + 4; int intPaddedX = (intIndex % SIZE_3(input)) + 4; int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; } ''' kernel_Correlation_updateOutput = ''' extern "C" __global__ void kernel_Correlation_updateOutput( const int n, const float* rbot0, const float* rbot1, float* top ) { extern __shared__ char patch_data_char[]; float *patch_data = (float *)patch_data_char; // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 int x1 = blockIdx.x + 4; int y1 = blockIdx.y + 4; int item = blockIdx.z; int ch_off = threadIdx.x; // Load 3D patch into shared shared memory for (int j = 0; j < 1; j++) { // HEIGHT for (int i = 0; i < 1; i++) { // WIDTH int ji_off = (j + i) * SIZE_3(rbot0); for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; int idxPatchData = ji_off + ch; patch_data[idxPatchData] = rbot0[idx1]; } } } __syncthreads(); __shared__ float sum[32]; // Compute correlation for (int 
top_channel = 0; top_channel < SIZE_1(top); top_channel++) { sum[ch_off] = 0; int s2o = top_channel % 9 - 4; int s2p = top_channel / 9 - 4; for (int j = 0; j < 1; j++) { // HEIGHT for (int i = 0; i < 1; i++) { // WIDTH int ji_off = (j + i) * SIZE_3(rbot0); for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS int x2 = x1 + s2o; int y2 = y1 + s2p; int idxPatchData = ji_off + ch; int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; } } } __syncthreads(); if (ch_off == 0) { float total_sum = 0; for (int idx = 0; idx < 32; idx++) { total_sum += sum[idx]; } const int sumelems = SIZE_3(rbot0); const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; } } } ''' kernel_Correlation_updateGradFirst = ''' #define ROUND_OFF 50000 extern "C" __global__ void kernel_Correlation_updateGradFirst( const int n, const int intSample, const float* rbot0, const float* rbot1, const float* gradOutput, float* gradFirst, float* gradSecond ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { int n = intIndex % SIZE_1(gradFirst); // channels int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos // round_off is a trick to enable integer division with ceil, even for negative numbers // We use a large offset, for the inner part not to become negative. 
const int round_off = ROUND_OFF; const int round_off_s1 = round_off; // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) // Same here: int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) float sum = 0; if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { xmin = max(0,xmin); xmax = min(SIZE_3(gradOutput)-1,xmax); ymin = max(0,ymin); ymax = min(SIZE_2(gradOutput)-1,ymax); for (int p = -4; p <= 4; p++) { for (int o = -4; o <= 4; o++) { // Get rbot1 data: int s2o = o; int s2p = p; int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] // Index offset for gradOutput in following loops: int op = (p+4) * 9 + (o+4); // index[o,p] int idxopoffset = (intSample * SIZE_1(gradOutput) + op); for (int y = ymin; y <= ymax; y++) { for (int x = xmin; x <= xmax; x++) { int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] sum += gradOutput[idxgradOutput] * bot1tmp; } } } } } const int sumelems = SIZE_1(gradFirst); const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; } } ''' kernel_Correlation_updateGradSecond = ''' #define ROUND_OFF 50000 extern "C" __global__ void kernel_Correlation_updateGradSecond( const int n, const int intSample, const float* rbot0, const float* rbot1, const float* gradOutput, float* gradFirst, float* gradSecond ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { int n = intIndex % 
SIZE_1(gradSecond); // channels int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos // round_off is a trick to enable integer division with ceil, even for negative numbers // We use a large offset, for the inner part not to become negative. const int round_off = ROUND_OFF; const int round_off_s1 = round_off; float sum = 0; for (int p = -4; p <= 4; p++) { for (int o = -4; o <= 4; o++) { int s2o = o; int s2p = p; //Get X,Y ranges and clamp // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) // Same here: int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { xmin = max(0,xmin); xmax = min(SIZE_3(gradOutput)-1,xmax); ymin = max(0,ymin); ymax = min(SIZE_2(gradOutput)-1,ymax); // Get rbot0 data: int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] // Index offset for gradOutput in following loops: int op = (p+4) * 9 + (o+4); // index[o,p] int idxopoffset = (intSample * SIZE_1(gradOutput) + op); for (int y = ymin; y <= ymax; y++) { for (int x = xmin; x <= xmax; x++) { int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] sum += gradOutput[idxgradOutput] * bot0tmp; } } } } } const int sumelems = SIZE_1(gradSecond); const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); gradSecond[bot1index + 
intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; } } ''' def cupy_kernel(strFunction, objVariables): strKernel = globals()[strFunction] while True: objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) if objMatch is None: break # end intArg = int(objMatch.group(2)) strTensor = objMatch.group(4) intSizes = objVariables[strTensor].size() strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) # end while True: objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) if objMatch is None: break # end intArgs = int(objMatch.group(2)) strArgs = objMatch.group(4).split(',') strTensor = strArgs[0] intStrides = objVariables[strTensor].stride() strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') # end return strKernel # end @cupy.memoize(for_each_device=True) def cupy_launch(strFunction, strKernel): return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) # end class _FunctionCorrelation(torch.autograd.Function): @staticmethod def forward(self, first, second): rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) self.save_for_backward(first, second, rbot0, rbot1) first = first.contiguous(); assert(first.is_cuda == True) second = second.contiguous(); assert(second.is_cuda == True) output = first.new_zeros([ first.shape[0], 81, first.shape[2], first.shape[3] ]) if first.is_cuda == True: n = first.shape[2] * first.shape[3] cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 'input': first, 'output': rbot0 }))( grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]), block=tuple([ 16, 1, 1 
]), args=[ n, first.data_ptr(), rbot0.data_ptr() ] ) n = second.shape[2] * second.shape[3] cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 'input': second, 'output': rbot1 }))( grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]), block=tuple([ 16, 1, 1 ]), args=[ n, second.data_ptr(), rbot1.data_ptr() ] ) n = output.shape[1] * output.shape[2] * output.shape[3] cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { 'rbot0': rbot0, 'rbot1': rbot1, 'top': output }))( grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]), block=tuple([ 32, 1, 1 ]), shared_mem=first.shape[1] * 4, args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ] ) elif first.is_cuda == False: raise NotImplementedError() # end return output # end @staticmethod def backward(self, gradOutput): first, second, rbot0, rbot1 = self.saved_tensors gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True) gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None if first.is_cuda == True: if gradFirst is not None: for intSample in range(first.shape[0]): n = first.shape[1] * first.shape[2] * first.shape[3] cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', { 'rbot0': rbot0, 'rbot1': rbot1, 'gradOutput': gradOutput, 'gradFirst': gradFirst, 'gradSecond': None }))( grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), block=tuple([ 512, 1, 1 ]), args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ] ) # end # end if gradSecond is not None: for intSample in range(first.shape[0]): n = first.shape[1] * first.shape[2] * first.shape[3] 
cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', { 'rbot0': rbot0, 'rbot1': rbot1, 'gradOutput': gradOutput, 'gradFirst': None, 'gradSecond': gradSecond }))( grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), block=tuple([ 512, 1, 1 ]), args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ] ) # end # end elif first.is_cuda == False: raise NotImplementedError() # end return gradFirst, gradSecond # end # end def FunctionCorrelation(tenFirst, tenSecond): return _FunctionCorrelation.apply(tenFirst, tenSecond) # end class ModuleCorrelation(torch.nn.Module): def __init__(self): super(ModuleCorrelation, self).__init__() # end def forward(self, tenFirst, tenSecond): return _FunctionCorrelation.apply(tenFirst, tenSecond) # end # end ================================================ FILE: metrics/flolpips/flolpips.py ================================================ from __future__ import absolute_import import os import numpy as np import torch import torch.nn as nn from torch.autograd import Variable import metrics.flolpips.pretrained_networks as pn import torch.nn import torch.nn.functional as F import torchvision.transforms.functional as TF import cv2 from .pwcnet import Network as PWCNet import metrics.flolpips.utils as utils def spatial_average(in_tens, keepdim=True): return in_tens.mean([2,3],keepdim=keepdim) def mw_spatial_average(in_tens, flow, keepdim=True): _,_,h,w = in_tens.shape flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear') flow_mag = torch.sqrt(flow[:,0:1]**2 + flow[:,1:2]**2) flow_mag = flow_mag / torch.sum(flow_mag, dim=[1,2,3], keepdim=True) return torch.sum(in_tens*flow_mag, dim=[2,3],keepdim=keepdim) def mtw_spatial_average(in_tens, flow, texture, keepdim=True): _,_,h,w = in_tens.shape flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear') texture = F.interpolate(texture, (h,w), align_corners=False, 
mode='bilinear') flow_mag = torch.sqrt(flow[:,0:1]**2 + flow[:,1:2]**2) flow_mag = (flow_mag - flow_mag.min()) / (flow_mag.max() - flow_mag.min()) + 1e-6 texture = (texture - texture.min()) / (texture.max() - texture.min()) + 1e-6 weight = flow_mag / texture weight /= torch.sum(weight) return torch.sum(in_tens*weight, dim=[2,3],keepdim=keepdim) def m2w_spatial_average(in_tens, flow, keepdim=True): _,_,h,w = in_tens.shape flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear') flow_mag = flow[:,0:1]**2 + flow[:,1:2]**2 # B,1,H,W flow_mag = flow_mag / torch.sum(flow_mag) return torch.sum(in_tens*flow_mag, dim=[2,3],keepdim=keepdim) def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W in_H, in_W = in_tens.shape[2], in_tens.shape[3] return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens) # Learned perceptual metric class LPIPS(nn.Module): def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False): # lpips - [True] means with linear calibration on top of base network # pretrained - [True] means load linear weights super(LPIPS, self).__init__() if(verbose): print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'% ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) self.pnet_type = net self.pnet_tune = pnet_tune self.pnet_rand = pnet_rand self.spatial = spatial self.lpips = lpips # false means baseline of just averaging all layers self.version = version self.scaling_layer = ScalingLayer() if(self.pnet_type in ['vgg','vgg16']): net_type = pn.vgg16 self.chns = [64,128,256,512,512] elif(self.pnet_type=='alex'): net_type = pn.alexnet self.chns = [64,192,384,256,256] elif(self.pnet_type=='squeeze'): net_type = pn.squeezenet self.chns = [64,128,256,384,384,512,512] self.L = len(self.chns) self.net = net_type(pretrained=not self.pnet_rand, 
requires_grad=self.pnet_tune) if(lpips): self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] if(self.pnet_type=='squeeze'): # 7 layers for squeezenet self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) self.lins+=[self.lin5,self.lin6] self.lins = nn.ModuleList(self.lins) if(pretrained): if(model_path is None): import inspect import os model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net))) if(verbose): print('Loading model from: %s'%model_path) self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) if(eval_mode): self.eval() def forward(self, in0, in1, retPerLayer=False, normalize=False): if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] in0 = 2 * in0 - 1 in1 = 2 * in1 - 1 # v0.0 - original release had a bug, where input was not scaled in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) feats0, feats1, diffs = {}, {}, {} for kk in range(self.L): feats0[kk], feats1[kk] = utils.normalize_tensor(outs0[kk]), utils.normalize_tensor(outs1[kk]) diffs[kk] = (feats0[kk]-feats1[kk])**2 if(self.lpips): if(self.spatial): res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] else: res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] else: if(self.spatial): res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] else: res = 
[spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] # val = res[0] # for l in range(1,self.L): # val += res[l] # print(val) # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True) # b = torch.max(self.lins[kk](feats0[kk]**2)) # for kk in range(self.L): # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True) # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2))) # a = a/self.L # from IPython import embed # embed() # return 10*torch.log10(b/a) # if(retPerLayer): # return (val, res) # else: return torch.sum(torch.cat(res, 1), dim=(1,2,3), keepdims=False) class ScalingLayer(nn.Module): def __init__(self): super(ScalingLayer, self).__init__() self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) def forward(self, inp): return (inp - self.shift) / self.scale class NetLinLayer(nn.Module): ''' A single linear layer which does a 1x1 conv ''' def __init__(self, chn_in, chn_out=1, use_dropout=False): super(NetLinLayer, self).__init__() layers = [nn.Dropout(),] if(use_dropout) else [] layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x) class Dist2LogitLayer(nn.Module): ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' def __init__(self, chn_mid=32, use_sigmoid=True): super(Dist2LogitLayer, self).__init__() layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] layers += [nn.LeakyReLU(0.2,True),] layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] layers += [nn.LeakyReLU(0.2,True),] layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] if(use_sigmoid): layers += [nn.Sigmoid(),] self.model = nn.Sequential(*layers) def forward(self,d0,d1,eps=0.1): return 
self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) class BCERankingLoss(nn.Module): def __init__(self, chn_mid=32): super(BCERankingLoss, self).__init__() self.net = Dist2LogitLayer(chn_mid=chn_mid) # self.parameters = list(self.net.parameters()) self.loss = torch.nn.BCELoss() def forward(self, d0, d1, judge): per = (judge+1.)/2. self.logit = self.net.forward(d0,d1) return self.loss(self.logit, per) # L2, DSSIM metrics class FakeNet(nn.Module): def __init__(self, use_gpu=True, colorspace='Lab'): super(FakeNet, self).__init__() self.use_gpu = use_gpu self.colorspace = colorspace class L2(FakeNet): def forward(self, in0, in1, retPerLayer=None): assert(in0.size()[0]==1) # currently only supports batchSize 1 if(self.colorspace=='RGB'): (N,C,X,Y) = in0.size() value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) return value elif(self.colorspace=='Lab'): value = utils.l2(utils.tensor2np(utils.tensor2tensorlab(in0.data,to_norm=False)), utils.tensor2np(utils.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') ret_var = Variable( torch.Tensor((value,) ) ) if(self.use_gpu): ret_var = ret_var.cuda() return ret_var class DSSIM(FakeNet): def forward(self, in0, in1, retPerLayer=None): assert(in0.size()[0]==1) # currently only supports batchSize 1 if(self.colorspace=='RGB'): value = utils.dssim(1.*utils.tensor2im(in0.data), 1.*utils.tensor2im(in1.data), range=255.).astype('float') elif(self.colorspace=='Lab'): value = utils.dssim(utils.tensor2np(utils.tensor2tensorlab(in0.data,to_norm=False)), utils.tensor2np(utils.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') ret_var = Variable( torch.Tensor((value,) ) ) if(self.use_gpu): ret_var = ret_var.cuda() return ret_var def print_network(net): num_params = 0 for param in net.parameters(): num_params += param.numel() print('Network',net) print('Total number of parameters: %d' % num_params) class FloLPIPS(LPIPS): 
def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False): super(FloLPIPS, self).__init__(pretrained, net, version, lpips, spatial, pnet_rand, pnet_tune, use_dropout, model_path, eval_mode, verbose) def forward(self, in0, in1, flow, retPerLayer=False, normalize=False): if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] in0 = 2 * in0 - 1 in1 = 2 * in1 - 1 in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) feats0, feats1, diffs = {}, {}, {} for kk in range(self.L): feats0[kk], feats1[kk] = utils.normalize_tensor(outs0[kk]), utils.normalize_tensor(outs1[kk]) diffs[kk] = (feats0[kk]-feats1[kk])**2 res = [mw_spatial_average(self.lins[kk](diffs[kk]), flow, keepdim=True) for kk in range(self.L)] return torch.sum(torch.cat(res, 1), dim=(1,2,3), keepdims=False) def calc_flolpips(dis_path, ref_path): batch_size = 8 # convert to yuv first os.system('ffmpeg -hide_banner -loglevel error -i {} flolpips_ref.yuv'.format(ref_path)) os.system('ffmpeg -hide_banner -loglevel error -i {} flolpips_dis.yuv'.format(dis_path)) loss_fn = FloLPIPS(net='alex',version='0.1').cuda() flownet = PWCNet().cuda() # batch_size = 128 cap_dis = cv2.VideoCapture(dis_path) cap_ref = cv2.VideoCapture(ref_path) assert int(cap_dis.get(cv2.CAP_PROP_FRAME_COUNT)) == int(cap_ref.get(cv2.CAP_PROP_FRAME_COUNT)) num_frames = int(cap_ref.get(cv2.CAP_PROP_FRAME_COUNT)) width = int(cap_ref.get(3)) height = int(cap_ref.get(4)) cap_dis.release() cap_ref.release() stream_dis = open('flolpips_dis.yuv', 'r') stream_ref = open('flolpips_ref.yuv', 'r') flolpips_list = [] batch_ref_list, batch_dis_list = [], [] batch_ref_next_list, batch_dis_next_list = [], [] for iFrame in range(num_frames-1): frame_dis = 
TF.to_tensor(utils.read_frame_yuv2rgb(stream_dis, width, height, iFrame, 8, '420')) frame_dis_next = TF.to_tensor(utils.read_frame_yuv2rgb(stream_dis, width, height, iFrame+1, 8, '420')) frame_ref = TF.to_tensor(utils.read_frame_yuv2rgb(stream_ref, width, height, iFrame, 8, '420')) frame_ref_next = TF.to_tensor(utils.read_frame_yuv2rgb(stream_ref, width, height, iFrame+1, 8, '420')) batch_dis_list.append(frame_dis) batch_dis_next_list.append(frame_dis_next) batch_ref_list.append(frame_ref) batch_ref_next_list.append(frame_ref_next) if len(batch_ref_list) % batch_size == 0: with torch.no_grad(): frames_ref = torch.stack(batch_ref_list, dim=0).cuda() frames_dis = torch.stack(batch_dis_list, dim=0).cuda() frames_ref_next = torch.stack(batch_ref_next_list, dim=0).cuda() frames_dis_next = torch.stack(batch_dis_next_list, dim=0).cuda() flow_ref = flownet(frames_ref, frames_ref_next) flow_dis = flownet(frames_dis, frames_dis_next) flow_diff = flow_ref - flow_dis flolpips = loss_fn.forward(frames_ref, frames_dis, flow_diff, normalize=True) batch_ref_list, batch_dis_list, batch_ref_next_list, batch_dis_next_list = [], [], [], [] flolpips_list = flolpips_list + flolpips.cpu().tolist() if len(batch_ref_list) > 0: with torch.no_grad(): frames_ref = torch.stack(batch_ref_list, dim=0).cuda() frames_dis = torch.stack(batch_dis_list, dim=0).cuda() frames_ref_next = torch.stack(batch_ref_next_list, dim=0).cuda() frames_dis_next = torch.stack(batch_dis_next_list, dim=0).cuda() flow_ref = flownet(frames_ref, frames_ref_next) flow_dis = flownet(frames_dis, frames_dis_next) flow_diff = flow_ref - flow_dis flolpips = loss_fn.forward(frames_ref, frames_dis, flow_diff, normalize=True) flolpips_list = flolpips_list + flolpips.cpu().tolist() stream_dis.close() stream_ref.close() # delete files, modify command accordingly os.remove('flolpips_dis.yuv') os.remove('flolpips_ref.yuv') return np.mean(flolpips_list) ================================================ FILE: 
metrics/flolpips/pretrained_networks.py ================================================ from collections import namedtuple import torch from torchvision import models as tv class squeezenet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(squeezenet, self).__init__() pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.slice6 = torch.nn.Sequential() self.slice7 = torch.nn.Sequential() self.N_slices = 7 for x in range(2): self.slice1.add_module(str(x), pretrained_features[x]) for x in range(2,5): self.slice2.add_module(str(x), pretrained_features[x]) for x in range(5, 8): self.slice3.add_module(str(x), pretrained_features[x]) for x in range(8, 10): self.slice4.add_module(str(x), pretrained_features[x]) for x in range(10, 11): self.slice5.add_module(str(x), pretrained_features[x]) for x in range(11, 12): self.slice6.add_module(str(x), pretrained_features[x]) for x in range(12, 13): self.slice7.add_module(str(x), pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1 = h h = self.slice2(h) h_relu2 = h h = self.slice3(h) h_relu3 = h h = self.slice4(h) h_relu4 = h h = self.slice5(h) h_relu5 = h h = self.slice6(h) h_relu6 = h h = self.slice7(h) h_relu7 = h vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) return out class alexnet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(alexnet, self).__init__() alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() 
self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 for x in range(2): self.slice1.add_module(str(x), alexnet_pretrained_features[x]) for x in range(2, 5): self.slice2.add_module(str(x), alexnet_pretrained_features[x]) for x in range(5, 8): self.slice3.add_module(str(x), alexnet_pretrained_features[x]) for x in range(8, 10): self.slice4.add_module(str(x), alexnet_pretrained_features[x]) for x in range(10, 12): self.slice5.add_module(str(x), alexnet_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1 = h h = self.slice2(h) h_relu2 = h h = self.slice3(h) h_relu3 = h h = self.slice4(h) h_relu4 = h h = self.slice5(h) h_relu5 = h alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) return out class vgg16(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(vgg16, self).__init__() vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 for x in range(4): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(4, 9): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(9, 16): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(16, 23): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(23, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1_2 = h h = self.slice2(h) h_relu2_2 = h h = self.slice3(h) h_relu3_3 = h h = self.slice4(h) h_relu4_3 = h h = self.slice5(h) 
h_relu5_3 = h vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) return out class resnet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True, num=18): super(resnet, self).__init__() if(num==18): self.net = tv.resnet18(pretrained=pretrained) elif(num==34): self.net = tv.resnet34(pretrained=pretrained) elif(num==50): self.net = tv.resnet50(pretrained=pretrained) elif(num==101): self.net = tv.resnet101(pretrained=pretrained) elif(num==152): self.net = tv.resnet152(pretrained=pretrained) self.N_slices = 5 self.conv1 = self.net.conv1 self.bn1 = self.net.bn1 self.relu = self.net.relu self.maxpool = self.net.maxpool self.layer1 = self.net.layer1 self.layer2 = self.net.layer2 self.layer3 = self.net.layer3 self.layer4 = self.net.layer4 def forward(self, X): h = self.conv1(X) h = self.bn1(h) h = self.relu(h) h_relu1 = h h = self.maxpool(h) h = self.layer1(h) h_conv2 = h h = self.layer2(h) h_conv3 = h h = self.layer3(h) h_conv4 = h h = self.layer4(h) h_conv5 = h outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) return out ================================================ FILE: metrics/flolpips/pwcnet.py ================================================ #!/usr/bin/env python import torch import getopt import math import numpy import os import PIL import PIL.Image import sys # try: from metrics.flolpips.correlation import correlation # the custom cost volume layer # except: # sys.path.insert(0, './correlation'); import correlation # you should consider upgrading python # end ########################################################## # assert(int(str('').join(torch.__version__.split('.')[0:2])) >= 13) # requires at least pytorch version 1.3.0 # torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance # 
torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance # ########################################################## # arguments_strModel = 'default' # 'default', or 'chairs-things' # arguments_strFirst = './images/first.png' # arguments_strSecond = './images/second.png' # arguments_strOut = './out.flo' # for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]: # if strOption == '--model' and strArgument != '': arguments_strModel = strArgument # which model to use # if strOption == '--first' and strArgument != '': arguments_strFirst = strArgument # path to the first frame # if strOption == '--second' and strArgument != '': arguments_strSecond = strArgument # path to the second frame # if strOption == '--out' and strArgument != '': arguments_strOut = strArgument # path to where the output should be stored # end ########################################################## def backwarp(tenInput, tenFlow): backwarp_tenGrid = {} backwarp_tenPartial = {} if str(tenFlow.shape) not in backwarp_tenGrid: tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1) tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3]) backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([ tenHor, tenVer ], 1).cuda() # end if str(tenFlow.shape) not in backwarp_tenPartial: backwarp_tenPartial[str(tenFlow.shape)] = tenFlow.new_ones([ tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3] ]) # end tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) tenInput = torch.cat([ tenInput, backwarp_tenPartial[str(tenFlow.shape)] ], 1) tenOutput = torch.nn.functional.grid_sample(input=tenInput, 
grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False) tenMask = tenOutput[:, -1:, :, :]; tenMask[tenMask > 0.999] = 1.0; tenMask[tenMask < 1.0] = 0.0 return tenOutput[:, :-1, :, :] * tenMask # end ########################################################## class Network(torch.nn.Module): def __init__(self): super(Network, self).__init__() class Extractor(torch.nn.Module): def __init__(self): super(Extractor, self).__init__() self.netOne = torch.nn.Sequential( torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netTwo = torch.nn.Sequential( torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netThr = torch.nn.Sequential( torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netFou = torch.nn.Sequential( torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, 
padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netFiv = torch.nn.Sequential( torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netSix = torch.nn.Sequential( torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) # end def forward(self, tenInput): tenOne = self.netOne(tenInput) tenTwo = self.netTwo(tenOne) tenThr = self.netThr(tenTwo) tenFou = self.netFou(tenThr) tenFiv = self.netFiv(tenFou) tenSix = self.netSix(tenFiv) return [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ] # end # end class Decoder(torch.nn.Module): def __init__(self, intLevel): super(Decoder, self).__init__() intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 1] intCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 0] if intLevel < 6: self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, 
kernel_size=4, stride=2, padding=1) if intLevel < 6: self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1) if intLevel < 6: self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1] self.netOne = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netTwo = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netThr = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netFou = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netFiv = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.netSix = torch.nn.Sequential( torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1) ) # end def forward(self, tenFirst, tenSecond, objPrevious): tenFlow = None tenFeat = None if objPrevious is None: tenFlow = None tenFeat = None tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=tenSecond), negative_slope=0.1, inplace=False) tenFeat = torch.cat([ tenVolume ], 1) elif objPrevious is not None: tenFlow = self.netUpflow(objPrevious['tenFlow']) tenFeat = self.netUpfeat(objPrevious['tenFeat']) tenVolume = 
torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=backwarp(tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp)), negative_slope=0.1, inplace=False) tenFeat = torch.cat([ tenVolume, tenFirst, tenFlow, tenFeat ], 1) # end tenFeat = torch.cat([ self.netOne(tenFeat), tenFeat ], 1) tenFeat = torch.cat([ self.netTwo(tenFeat), tenFeat ], 1) tenFeat = torch.cat([ self.netThr(tenFeat), tenFeat ], 1) tenFeat = torch.cat([ self.netFou(tenFeat), tenFeat ], 1) tenFeat = torch.cat([ self.netFiv(tenFeat), tenFeat ], 1) tenFlow = self.netSix(tenFeat) return { 'tenFlow': tenFlow, 'tenFeat': tenFeat } # end # end class Refiner(torch.nn.Module): def __init__(self): super(Refiner, self).__init__() self.netMain = torch.nn.Sequential( torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1), torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1) ) # end def forward(self, tenInput): return self.netMain(tenInput) # end # end self.netExtractor = Extractor() self.netTwo = Decoder(2) self.netThr = Decoder(3) 
self.netFou = Decoder(4) self.netFiv = Decoder(5) self.netSix = Decoder(6) self.netRefiner = Refiner() self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/github/pytorch-pwc/network-' + 'default' + '.pytorch').items() }) # end def forward(self, tenFirst, tenSecond, *args): intWidth = tenFirst.shape[3] intHeight = tenFirst.shape[2] intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0)) intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0)) # optionally pass pre-extracted feature pyramid in as args if len(args) == 0: tenPreprocessedFirst = torch.nn.functional.interpolate(input=tenFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) tenPreprocessedSecond = torch.nn.functional.interpolate(input=tenSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) tenFirst = self.netExtractor(tenPreprocessedFirst) tenSecond = self.netExtractor(tenPreprocessedSecond) else: tenFirst, tenSecond = args objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None) objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate) objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate) objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate) objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate) tenFlow = objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat']) tenFlow = 20.0 * torch.nn.functional.interpolate(input=tenFlow, size=(intHeight, intWidth), mode='bilinear', align_corners=False) tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) return tenFlow # end # end def extract_pyramid_single(self, tenFirst): intWidth = tenFirst.shape[3] intHeight = tenFirst.shape[2] intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 
64.0) * 64.0)) intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0)) tenPreprocessedFirst = torch.nn.functional.interpolate(input=tenFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) return self.netExtractor(tenPreprocessedFirst) netNetwork = None ########################################################## def estimate(tenFirst, tenSecond): global netNetwork if netNetwork is None: netNetwork = Network().cuda().eval() # end assert(tenFirst.shape[1] == tenSecond.shape[1]) assert(tenFirst.shape[2] == tenSecond.shape[2]) intWidth = tenFirst.shape[2] intHeight = tenFirst.shape[1] assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue assert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth) tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth) intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0)) intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0)) tenPreprocessedFirst = torch.nn.functional.interpolate(input=tenPreprocessedFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) tenPreprocessedSecond = torch.nn.functional.interpolate(input=tenPreprocessedSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) tenFlow = 20.0 * torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedFirst, tenPreprocessedSecond), size=(intHeight, intWidth), mode='bilinear', align_corners=False) tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) return tenFlow[0, :, :, :].cpu() # end 
########################################################## # if __name__ == '__main__': # tenFirst = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strFirst))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))) # tenSecond = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strSecond))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))) # tenOutput = estimate(tenFirst, tenSecond) # objOutput = open(arguments_strOut, 'wb') # numpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objOutput) # numpy.array([ tenOutput.shape[2], tenOutput.shape[1] ], numpy.int32).tofile(objOutput) # numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput) # objOutput.close() # end ================================================ FILE: metrics/flolpips/utils.py ================================================ import numpy as np import cv2 import torch def normalize_tensor(in_feat,eps=1e-10): norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) return in_feat/(norm_factor+eps) def l2(p0, p1, range=255.): return .5*np.mean((p0 / range - p1 / range)**2) def dssim(p0, p1, range=255.): from skimage.measure import compare_ssim return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 
def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): image_numpy = image_tensor[0].cpu().float().numpy() image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor return image_numpy.astype(imtype) def tensor2np(tensor_obj): # change dimension of a tensor object into a numpy array return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) def np2tensor(np_obj): # change dimension of np array into tensor array return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): # image tensor to lab tensor from skimage import color img = tensor2im(image_tensor) img_lab = color.rgb2lab(img) if(mc_only): img_lab[:,:,0] = img_lab[:,:,0]-50 if(to_norm and not mc_only): img_lab[:,:,0] = img_lab[:,:,0]-50 img_lab = img_lab/100. return np2tensor(img_lab) def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt='420'): if pix_fmt == '420': multiplier = 1 uv_factor = 2 elif pix_fmt == '444': multiplier = 2 uv_factor = 1 else: print('Pixel format {} is not supported'.format(pix_fmt)) return if bit_depth == 8: datatype = np.uint8 # cast to int: the 1.5 factor makes the offset a float, and binary-mode streams reject float offsets stream.seek(int(iFrame*1.5*width*height*multiplier)) Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width)) # read chroma samples and upsample since original is 4:2:0 sampling U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ reshape((height//uv_factor, width//uv_factor)) V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ reshape((height//uv_factor, width//uv_factor)) else: datatype = np.uint16 stream.seek(iFrame*3*width*height*multiplier) Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width)) U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ reshape((height//uv_factor, width//uv_factor)) V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
reshape((height//uv_factor, width//uv_factor)) if pix_fmt == '420': yuv = np.empty((height*3//2, width), dtype=datatype) yuv[0:height,:] = Y yuv[height:height+height//4,:] = U.reshape(-1, width) yuv[height+height//4:,:] = V.reshape(-1, width) if bit_depth != 8: yuv = (yuv/(2**bit_depth-1)*255).astype(np.uint8) #convert to rgb rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420) else: yvu = np.stack([Y,V,U],axis=2) if bit_depth != 8: yvu = (yvu/(2**bit_depth-1)*255).astype(np.uint8) rgb = cv2.cvtColor(yvu, cv2.COLOR_YCrCb2RGB) return rgb ================================================ FILE: metrics/lpips/__init__.py ================================================ from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np import torch # from torch.autograd import Variable # from .trainer import * from .lpips import * # class PerceptualLoss(torch.nn.Module): # def __init__(self, model='lpips', net='alex', spatial=False, use_gpu=False, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric) # # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss # super(PerceptualLoss, self).__init__() # print('Setting up Perceptual loss...') # self.use_gpu = use_gpu # self.spatial = spatial # self.gpu_ids = gpu_ids # self.model = dist_model.DistModel() # self.model.initialize(model=model, net=net, use_gpu=use_gpu, spatial=self.spatial, gpu_ids=gpu_ids, version=version) # print('...[%s] initialized'%self.model.name()) # print('...Done') # def forward(self, pred, target, normalize=False): # """ # Pred and target are Variables. 
# If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] # If normalize is False, assumes the images are already between [-1,+1] # Inputs pred and target are Nx3xHxW # Output pytorch Variable N long # """ # if normalize: # target = 2 * target - 1 # pred = 2 * pred - 1 # return self.model.forward(target, pred) def normalize_tensor(in_feat,eps=1e-10): norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) return in_feat/(norm_factor+eps) def l2(p0, p1, range=255.): return .5*np.mean((p0 / range - p1 / range)**2) def psnr(p0, p1, peak=255.): return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) def dssim(p0, p1, range=255.): from skimage.measure import compare_ssim return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. def rgb2lab(in_img,mean_cent=False): from skimage import color img_lab = color.rgb2lab(in_img) if(mean_cent): img_lab[:,:,0] = img_lab[:,:,0]-50 return img_lab def tensor2np(tensor_obj): # change dimension of a tensor object into a numpy array return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) def np2tensor(np_obj): # change dimenion of np array into tensor array return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): # image tensor to lab tensor from skimage import color img = tensor2im(image_tensor) img_lab = color.rgb2lab(img) if(mc_only): img_lab[:,:,0] = img_lab[:,:,0]-50 if(to_norm and not mc_only): img_lab[:,:,0] = img_lab[:,:,0]-50 img_lab = img_lab/100. return np2tensor(img_lab) def tensorlab2tensor(lab_tensor,return_inbnd=False): from skimage import color import warnings warnings.filterwarnings("ignore") lab = tensor2np(lab_tensor)*100. lab[:,:,0] = lab[:,:,0]+50 rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) if(return_inbnd): # convert back to lab, see if we match lab_back = color.rgb2lab(rgb_back.astype('uint8')) mask = 1.*np.isclose(lab_back,lab,atol=2.) 
mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) return (im2tensor(rgb_back),mask) else: return im2tensor(rgb_back) def load_image(path): if(path[-3:] == 'dng'): import rawpy with rawpy.imread(path) as raw: img = raw.postprocess() elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png' or path[-4:]=='jpeg'): import cv2 return cv2.imread(path)[:,:,::-1] else: img = (255*plt.imread(path)[:,:,:3]).astype('uint8') return img def rgb2lab(input): from skimage import color return color.rgb2lab(input / 255.) def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): image_numpy = image_tensor[0].cpu().float().numpy() image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor return image_numpy.astype(imtype) def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): return torch.Tensor((image / factor - cent) [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) def tensor2vec(vector_tensor): return vector_tensor.data.cpu().numpy()[:, :, 0, 0] def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): image_numpy = image_tensor[0].cpu().float().numpy() image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor return image_numpy.astype(imtype) def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): return torch.Tensor((image / factor - cent) [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) def voc_ap(rec, prec, use_07_metric=False): """ ap = voc_ap(rec, prec, [use_07_metric]) Compute VOC AP given precision and recall. If use_07_metric is true, uses the VOC 07 11 point method (default:False). """ if use_07_metric: # 11 point metric ap = 0. for t in np.arange(0., 1.1, 0.1): if np.sum(rec >= t) == 0: p = 0 else: p = np.max(prec[rec >= t]) ap = ap + p / 11. 
    else:
        # correct AP calculation
        # first append sentinel values at the end
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))

        # compute the precision envelope
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

        # to calculate area under PR curve, look for points
        # where X axis (recall) changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]

        # and sum (\Delta recall) * prec
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap


================================================
FILE: metrics/lpips/lpips.py
================================================

from __future__ import absolute_import

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import numpy as np
from . import pretrained_networks as pn
import torch.nn

from .. import lpips

def spatial_average(in_tens, keepdim=True):
    if torch.__version__ == '0.4.0':
        if keepdim:
            return in_tens.mean(2, keepdim=True).mean(3, keepdim=True)
        else:
            # the original dropped this result; return it so the 0.4.0 path is complete
            return in_tens.mean(2).mean(2)
    return in_tens.mean([2,3],keepdim=keepdim)

def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W
    in_H, in_W = in_tens.shape[2], in_tens.shape[3]
    return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)

# Learned perceptual metric
class LPIPS(nn.Module):
    def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
        pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False):
        # lpips - [True] means with linear calibration on top of base network
        # pretrained - [True] means load linear weights

        super(LPIPS, self).__init__()
        if(verbose):
            print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'%
                ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips # false means baseline of just
averaging all layers self.version = version self.scaling_layer = ScalingLayer() if(self.pnet_type in ['vgg','vgg16']): net_type = pn.vgg16 self.chns = [64,128,256,512,512] elif(self.pnet_type=='alex'): net_type = pn.alexnet self.chns = [64,192,384,256,256] elif(self.pnet_type=='squeeze'): net_type = pn.squeezenet self.chns = [64,128,256,384,384,512,512] self.L = len(self.chns) self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) if(lpips): self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] if(self.pnet_type=='squeeze'): # 7 layers for squeezenet self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) self.lins+=[self.lin5,self.lin6] self.lins = nn.ModuleList(self.lins) if(pretrained): if(model_path is None): import inspect import os model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net))) if(verbose): print('Loading model from: %s'%model_path) self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) if(eval_mode): self.eval() def forward(self, in0, in1, retPerLayer=False, normalize=False): if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] in0 = 2 * in0 - 1 in1 = 2 * in1 - 1 # v0.0 - original release had a bug, where input was not scaled in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) feats0, feats1, diffs = {}, {}, {} for kk in range(self.L): feats0[kk], feats1[kk] = 
lpips.normalize_tensor(outs0[kk]), lpips.normalize_tensor(outs1[kk]) diffs[kk] = (feats0[kk]-feats1[kk])**2 if(self.lpips): if(self.spatial): res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] else: res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] else: if(self.spatial): res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] else: res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] # val = res[0] # for l in range(1,self.L): # val += res[l] # print(val) # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True) # b = torch.max(self.lins[kk](feats0[kk]**2)) # for kk in range(self.L): # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True) # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2))) # a = a/self.L # from IPython import embed # embed() # return 10*torch.log10(b/a) # if(retPerLayer): # return (val, res) # else: # return torch.sum(torch.cat(res, 1), dim=(1,2,3), keepdims=False) return torch.cat(res, 1).sum(dim=1).sum(dim=1).sum(dim=1) class ScalingLayer(nn.Module): def __init__(self): super(ScalingLayer, self).__init__() self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) def forward(self, inp): return (inp - self.shift) / self.scale class NetLinLayer(nn.Module): ''' A single linear layer which does a 1x1 conv ''' def __init__(self, chn_in, chn_out=1, use_dropout=False): super(NetLinLayer, self).__init__() layers = [nn.Dropout(),] if(use_dropout) else [] layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x) class Dist2LogitLayer(nn.Module): ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' def __init__(self, chn_mid=32, 
use_sigmoid=True): super(Dist2LogitLayer, self).__init__() layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] layers += [nn.LeakyReLU(0.2,True),] layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] layers += [nn.LeakyReLU(0.2,True),] layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] if(use_sigmoid): layers += [nn.Sigmoid(),] self.model = nn.Sequential(*layers) def forward(self,d0,d1,eps=0.1): return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) class BCERankingLoss(nn.Module): def __init__(self, chn_mid=32): super(BCERankingLoss, self).__init__() self.net = Dist2LogitLayer(chn_mid=chn_mid) # self.parameters = list(self.net.parameters()) self.loss = torch.nn.BCELoss() def forward(self, d0, d1, judge): per = (judge+1.)/2. self.logit = self.net.forward(d0,d1) return self.loss(self.logit, per) # L2, DSSIM metrics class FakeNet(nn.Module): def __init__(self, use_gpu=True, colorspace='Lab'): super(FakeNet, self).__init__() self.use_gpu = use_gpu self.colorspace = colorspace class L2(FakeNet): def forward(self, in0, in1, retPerLayer=None): assert(in0.size()[0]==1) # currently only supports batchSize 1 if(self.colorspace=='RGB'): (N,C,X,Y) = in0.size() value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) return value elif(self.colorspace=='Lab'): value = lpips.l2(lpips.tensor2np(lpips.tensor2tensorlab(in0.data,to_norm=False)), lpips.tensor2np(lpips.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') ret_var = Variable( torch.Tensor((value,) ) ) if(self.use_gpu): ret_var = ret_var.cuda() return ret_var class DSSIM(FakeNet): def forward(self, in0, in1, retPerLayer=None): assert(in0.size()[0]==1) # currently only supports batchSize 1 if(self.colorspace=='RGB'): value = lpips.dssim(1.*lpips.tensor2im(in0.data), 1.*lpips.tensor2im(in1.data), range=255.).astype('float') elif(self.colorspace=='Lab'): value = 
lpips.dssim(lpips.tensor2np(lpips.tensor2tensorlab(in0.data,to_norm=False)), lpips.tensor2np(lpips.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') ret_var = Variable( torch.Tensor((value,) ) ) if(self.use_gpu): ret_var = ret_var.cuda() return ret_var def print_network(net): num_params = 0 for param in net.parameters(): num_params += param.numel() print('Network',net) print('Total number of parameters: %d' % num_params) ================================================ FILE: metrics/lpips/pretrained_networks.py ================================================ from collections import namedtuple import torch from torchvision import models as tv class squeezenet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(squeezenet, self).__init__() pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.slice6 = torch.nn.Sequential() self.slice7 = torch.nn.Sequential() self.N_slices = 7 for x in range(2): self.slice1.add_module(str(x), pretrained_features[x]) for x in range(2,5): self.slice2.add_module(str(x), pretrained_features[x]) for x in range(5, 8): self.slice3.add_module(str(x), pretrained_features[x]) for x in range(8, 10): self.slice4.add_module(str(x), pretrained_features[x]) for x in range(10, 11): self.slice5.add_module(str(x), pretrained_features[x]) for x in range(11, 12): self.slice6.add_module(str(x), pretrained_features[x]) for x in range(12, 13): self.slice7.add_module(str(x), pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1 = h h = self.slice2(h) h_relu2 = h h = self.slice3(h) h_relu3 = h h = self.slice4(h) h_relu4 = h h = self.slice5(h) h_relu5 = h h = self.slice6(h) h_relu6 = h h = self.slice7(h) h_relu7 
= h vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) return out class alexnet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(alexnet, self).__init__() alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 for x in range(2): self.slice1.add_module(str(x), alexnet_pretrained_features[x]) for x in range(2, 5): self.slice2.add_module(str(x), alexnet_pretrained_features[x]) for x in range(5, 8): self.slice3.add_module(str(x), alexnet_pretrained_features[x]) for x in range(8, 10): self.slice4.add_module(str(x), alexnet_pretrained_features[x]) for x in range(10, 12): self.slice5.add_module(str(x), alexnet_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1 = h h = self.slice2(h) h_relu2 = h h = self.slice3(h) h_relu3 = h h = self.slice4(h) h_relu4 = h h = self.slice5(h) h_relu5 = h alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) return out class vgg16(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(vgg16, self).__init__() vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 for x in range(4): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(4, 9): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in 
range(9, 16): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(16, 23): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(23, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1_2 = h h = self.slice2(h) h_relu2_2 = h h = self.slice3(h) h_relu3_3 = h h = self.slice4(h) h_relu4_3 = h h = self.slice5(h) h_relu5_3 = h vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) return out class resnet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True, num=18): super(resnet, self).__init__() if(num==18): self.net = tv.resnet18(pretrained=pretrained) elif(num==34): self.net = tv.resnet34(pretrained=pretrained) elif(num==50): self.net = tv.resnet50(pretrained=pretrained) elif(num==101): self.net = tv.resnet101(pretrained=pretrained) elif(num==152): self.net = tv.resnet152(pretrained=pretrained) self.N_slices = 5 self.conv1 = self.net.conv1 self.bn1 = self.net.bn1 self.relu = self.net.relu self.maxpool = self.net.maxpool self.layer1 = self.net.layer1 self.layer2 = self.net.layer2 self.layer3 = self.net.layer3 self.layer4 = self.net.layer4 def forward(self, X): h = self.conv1(X) h = self.bn1(h) h = self.relu(h) h_relu1 = h h = self.maxpool(h) h = self.layer1(h) h_conv2 = h h = self.layer2(h) h_conv3 = h h = self.layer3(h) h_conv4 = h h = self.layer4(h) h_conv5 = h outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) return out ================================================ FILE: metrics/pytorch_ssim/__init__.py ================================================ import torch import torch.nn.functional as F from torch.autograd import Variable import numpy as np from math import 
exp def gaussian(window_size, sigma): gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) return gauss/gauss.sum() def create_window(window_size, channel): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) return window def create_window_3d(window_size, channel=1): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) _2D_window = _1D_window.mm(_1D_window.t()) _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) window = _3D_window.expand(1, channel, window_size, window_size, window_size).contiguous().cuda() return window def _ssim(img1, img2, window, window_size, channel, size_average = True): mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) mu1_mu2 = mu1*mu2 sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 C1 = 0.01**2 C2 = 0.03**2 ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) if size_average: return ssim_map.mean() else: return ssim_map.mean(1).mean(1).mean(1) def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=255): # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 
if val_range is None: if torch.max(img1) > 128: max_val = 255 else: max_val = 1 if torch.min(img1) < -0.5: min_val = -1 else: min_val = 0 L = max_val - min_val else: L = val_range padd = 0 (_, _, height, width) = img1.size() if window is None: real_size = min(window_size, height, width) window = create_window_3d(real_size, channel=1).to(img1.device) # Channel is set to 1 since we consider color images as volumetric images img1 = img1.unsqueeze(1) img2 = img2.unsqueeze(1) mu1 = F.conv3d(F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1) mu2 = F.conv3d(F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), window, padding=padd, groups=1) mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) mu1_mu2 = mu1 * mu2 sigma1_sq = F.conv3d(F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_sq sigma2_sq = F.conv3d(F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu2_sq sigma12 = F.conv3d(F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), window, padding=padd, groups=1) - mu1_mu2 C1 = (0.01 * L) ** 2 C2 = (0.03 * L) ** 2 v1 = 2.0 * sigma12 + C2 v2 = sigma1_sq + sigma2_sq + C2 cs = torch.mean(v1 / v2) # contrast sensitivity ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) if size_average: ret = ssim_map.mean() else: ret = ssim_map.mean(1).mean(1).mean(1).mean(1) if full: return ret, cs return ret class SSIM(torch.nn.Module): def __init__(self, window_size = 11, size_average = True): super(SSIM, self).__init__() self.window_size = window_size self.size_average = size_average self.channel = 1 self.window = create_window(window_size, self.channel) def forward(self, img1, img2): (_, channel, _, _) = img1.size() if channel == self.channel and self.window.data.type() == img1.data.type(): window = self.window else: window = create_window(self.window_size, channel) if img1.is_cuda: window = window.cuda(img1.get_device()) window = window.type_as(img1) self.window = window 
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

def ssim(img1, img2, window_size = 11, size_average = True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)


================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages

setup(
    name='latent-diffusion',
    version='0.0.1',
    description='',
    packages=find_packages(),
    install_requires=[
        'torch',
        'numpy',
        'tqdm',
    ],
)


================================================
FILE: utility.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from metrics import pytorch_ssim, lpips, flolpips


def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt='420'):
    if pix_fmt == '420':
        multiplier = 1
        uv_factor = 2
    elif pix_fmt == '444':
        multiplier = 2
        uv_factor = 1
    else:
        print('Pixel format {} is not supported'.format(pix_fmt))
        return

    if bit_depth == 8:
        datatype = np.uint8
        # file offsets must be integers; the 1.5 factor makes the product a float
        stream.seek(int(iFrame*1.5*width*height*multiplier))
        Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width))

        # read chroma samples and upsample since original is 4:2:0 sampling
        U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
                                reshape((height//uv_factor, width//uv_factor))
        V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
                                reshape((height//uv_factor, width//uv_factor))

    else:
        datatype = np.uint16
        stream.seek(iFrame*3*width*height*multiplier)
        Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width))

        U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
                                reshape((height//uv_factor, width//uv_factor))
        V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
                                reshape((height//uv_factor, width//uv_factor))

    if pix_fmt == '420':
        yuv = np.empty((height*3//2, width), dtype=datatype)
        yuv[0:height,:] = Y

        yuv[height:height+height//4,:] = U.reshape(-1, width)
        yuv[height+height//4:,:] = V.reshape(-1, width)

        if bit_depth != 8:
            yuv = (yuv/(2**bit_depth-1)*255).astype(np.uint8)

        #convert to rgb
        rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420)

    else:
        yvu = np.stack([Y,V,U],axis=2)
        if bit_depth != 8:
            yvu = (yvu/(2**bit_depth-1)*255).astype(np.uint8)
        rgb = cv2.cvtColor(yvu, cv2.COLOR_YCrCb2RGB)

    return rgb


def CharbonnierFunc(data, epsilon=0.001):
    return torch.mean(torch.sqrt(data ** 2 + epsilon ** 2))


def moduleNormalize(frame):
    return torch.cat([(frame[:, 0:1, :, :] - 0.4631), (frame[:, 1:2, :, :] - 0.4352), (frame[:, 2:3, :, :] - 0.3990)], 1)


def gaussian_kernel(sz, sigma):
    k = torch.arange(-(sz-1)/2, (sz+1)/2)
    k = torch.exp(-1.0/(2*sigma**2) * k**2)
    k = k.reshape(-1, 1) * k.reshape(1, -1)
    k = k / torch.sum(k)
    return k


def quantize(imTensor):
    return ((imTensor.clamp(-1.0, 1.0)+1.)/2.).mul(255).round()


def tensor2rgb(tensor):
    """
    Convert GPU Tensor to RGB image (numpy array)
    """
    out = []
    for b in range(tensor.shape[0]):
        out.append(np.moveaxis(quantize(tensor[b]).cpu().detach().numpy(), 0, 2).astype(np.uint8))
    return np.array(out) #(B,H,W,C)


def calc_psnr(gt, out, *args):
    """
    args:
    gt, out -- (B,3,H,W) cuda Tensors in [-1,1]
    """
    mse = torch.mean((quantize(gt) - quantize(out))**2, dim=1).mean(1).mean(1)
    return -10 * torch.log10(mse/255**2 + 1e-8) # (B,)


def calc_ssim(gt, out, *args):
    return pytorch_ssim.ssim_matlab(quantize(gt), quantize(out), size_average=False)


def calc_lpips(gt, out, *args):
    loss_fn = lpips.LPIPS(net='alex',version='0.1').cuda()
    # return loss_fn.forward(gt, out, normalize=True)
    return loss_fn.forward(quantize(gt)/255., quantize(out)/255., normalize=True)


def calc_flolpips(gt_list, out_list, inputs_list):
    '''
    gt, out - list of (B,3,H,W) cuda Tensors in [-1,1]
    inputs - list of (B,3,H,W) cuda Tensors in [-1,1], one more than gt
             (two tensors when gt holds a single frame)
    e.g. gt can contain frames 1,3,5... while inputs contains frames 0,2,4,6...
    '''

    loss_fn = flolpips.FloLPIPS(net='alex',version='0.1').cuda()
    flownet = flolpips.PWCNet().cuda()

    scores = []
    for i in range(len(gt_list)):
        frame_ref = (gt_list[i] + 1.) / 2.
        frame_dis = (out_list[i] + 1.) / 2.
        # after the first triplet, the previous input frame is the last triplet's next frame
        frame_prev = (inputs_list[i] + 1.) / 2. if i == 0 else frame_next
        frame_next = (inputs_list[i+1] + 1.) / 2.

        with torch.no_grad():
            feat_ref = flownet.extract_pyramid_single(frame_ref)
            feat_dis = flownet.extract_pyramid_single(frame_dis)
            feat_prev = flownet.extract_pyramid_single(frame_prev) if i == 0 else feat_next
            feat_next = flownet.extract_pyramid_single(frame_next)

            # flow between the middle frame and the next input frame
            flow_ref = flownet(frame_ref, frame_next, feat_ref, feat_next)
            flow_dis = flownet(frame_dis, frame_next, feat_dis, feat_next)
            flow_diff = flow_ref - flow_dis
            scores.append(loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True).item())

            # flow between the middle frame and the previous input frame
            flow_ref = flownet(frame_ref, frame_prev, feat_ref, feat_prev)
            flow_dis = flownet(frame_dis, frame_prev, feat_dis, feat_prev)
            flow_diff = flow_ref - flow_dis
            scores.append(loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True).item())

    return np.mean(scores)
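
# Illustrative note (not part of utility.py): a minimal sketch of the frame
# pairing that calc_flolpips appears to rely on, assuming gt holds the odd
# video frames (1,3,5...) and inputs holds the even ones (0,2,4,6...), as in
# its docstring. The helper name flolpips_triplets is hypothetical and only
# enumerates indices; it does not touch tensors or any model.

```python
def flolpips_triplets(num_gt):
    """Return (prev, middle, next) video frame numbers for each scored frame.

    Assumes inputs_list has num_gt + 1 entries, so that middle frame i sits
    between inputs_list[i] and inputs_list[i + 1].
    """
    triplets = []
    for i in range(num_gt):
        prev_idx, next_idx = i, i + 1  # positions in inputs_list
        # inputs_list[k] is video frame 2k; gt_list[i] is video frame 2i + 1
        triplets.append((2 * prev_idx, 2 * i + 1, 2 * next_idx))
    return triplets


def num_flolpips_scores(num_gt):
    """Each triplet contributes two scores: middle/next and middle/previous."""
    return 2 * num_gt


print(flolpips_triplets(3))    # [(0, 1, 2), (2, 3, 4), (4, 5, 6)]
print(num_flolpips_scores(3))  # 6
```

# Under these assumptions, three interpolated frames need four input frames,
# and the value returned by calc_flolpips is the mean over six scores.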