Repository: fudan-generative-vision/DicFace
Branch: main
Commit: df08bb76d051
Files: 100
Total size: 687.8 KB

Directory structure:
gitextract_tgnuarg3/

├── .gitignore
├── README.md
├── basicsr/
│   ├── VERSION
│   ├── __init__.py
│   ├── archs/
│   │   ├── __init__.py
│   │   ├── arcface_arch.py
│   │   ├── arch_util.py
│   │   ├── dir_dist_codeformer_multiscale_arch.py
│   │   ├── rrdbnet_arch.py
│   │   ├── vgg_arch.py
│   │   └── vqgan_arch.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── color_dataset.py
│   │   ├── data_sampler.py
│   │   ├── data_util.py
│   │   ├── degradations.py
│   │   ├── gaussian_kernels.py
│   │   ├── inpainting_dataset.py
│   │   ├── paired_image_dataset.py
│   │   ├── prefetch_dataloader.py
│   │   ├── transforms.py
│   │   └── vfhq_dataset.py
│   ├── losses/
│   │   ├── __init__.py
│   │   ├── loss_util.py
│   │   └── losses.py
│   ├── metrics/
│   │   ├── __init__.py
│   │   ├── metric_util.py
│   │   └── psnr_ssim.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── codeformer_dirichlet_video_model.py
│   │   ├── lr_scheduler.py
│   │   ├── sr_model.py
│   │   └── vqgan_model.py
│   ├── ops/
│   │   ├── __init__.py
│   │   ├── dcn/
│   │   │   ├── __init__.py
│   │   │   ├── deform_conv.py
│   │   │   └── src/
│   │   │       ├── deform_conv_cuda.cpp
│   │   │       ├── deform_conv_cuda_kernel.cu
│   │   │       └── deform_conv_ext.cpp
│   │   ├── fused_act/
│   │   │   ├── __init__.py
│   │   │   ├── fused_act.py
│   │   │   └── src/
│   │   │       ├── fused_bias_act.cpp
│   │   │       └── fused_bias_act_kernel.cu
│   │   └── upfirdn2d/
│   │       ├── __init__.py
│   │       ├── src/
│   │       │   ├── upfirdn2d.cpp
│   │       │   └── upfirdn2d_kernel.cu
│   │       └── upfirdn2d.py
│   ├── setup.py
│   ├── train.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── dist_util.py
│   │   ├── download_util.py
│   │   ├── file_client.py
│   │   ├── img_util.py
│   │   ├── lmdb_util.py
│   │   ├── logger.py
│   │   ├── matlab_functions.py
│   │   ├── misc.py
│   │   ├── options.py
│   │   ├── realesrgan_utils.py
│   │   ├── registry.py
│   │   └── video_util.py
│   └── version.py
├── facelib/
│   ├── detection/
│   │   ├── __init__.py
│   │   ├── align_trans.py
│   │   ├── matlab_cp2tform.py
│   │   ├── retinaface/
│   │   │   ├── retinaface.py
│   │   │   ├── retinaface_net.py
│   │   │   └── retinaface_utils.py
│   │   └── yolov5face/
│   │       ├── __init__.py
│   │       ├── face_detector.py
│   │       ├── models/
│   │       │   ├── __init__.py
│   │       │   ├── common.py
│   │       │   ├── experimental.py
│   │       │   ├── yolo.py
│   │       │   ├── yolov5l.yaml
│   │       │   └── yolov5n.yaml
│   │       └── utils/
│   │           ├── __init__.py
│   │           ├── autoanchor.py
│   │           ├── datasets.py
│   │           ├── extract_ckpt.py
│   │           ├── general.py
│   │           └── torch_utils.py
│   ├── parsing/
│   │   ├── __init__.py
│   │   ├── bisenet.py
│   │   ├── parsenet.py
│   │   └── resnet.py
│   └── utils/
│       ├── __init__.py
│       ├── face_restoration_helper.py
│       ├── face_utils.py
│       └── misc.py
├── options/
│   ├── clip5_bs2_512_align_nofix_multiscale.yaml
│   ├── clip5_bs2_512_align_nofix_multiscale_color.yaml
│   └── clip5_bs2_512_align_nofix_multiscale_inpaint.yaml
├── requirements.txt
├── scripts/
│   ├── inference.py
│   ├── inference_color_and_inpainting.py
│   └── warp_images.py
└── train.sh

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Ignore OS-generated files
.DS_Store
Thumbs.db

# Ignore compiled build artifacts
*.class
*.exe
*.o
*.so
.eggs/
*.egg-info/


# Ignore package-manager directories
node_modules/
vendor/

# Ignore Python cache directories
__pycache__/

# Ignore log files
*.log

# Ignore environment configuration files
.env

# Ignore IDE/editor configuration files
.idea/
.vscode/

# test folder
test*/

# ckpts

ckpts/


================================================
FILE: README.md
================================================
<h1 align='center'>DicFace: Dirichlet-Constrained Variational Codebook Learning for Temporally Coherent Video Face Restoration</h1>

<div align='center'>
    <a href='' target='_blank'>Yan Chen</a><sup>1*</sup>&emsp;
    <a href='' target='_blank'>Hanlin Shang</a><sup>1*</sup>&emsp;
    <a href='' target='_blank'>Ce Liu</a><sup>1</sup>&emsp;
    <a href='' target='_blank'>Yuxuan Chen</a><sup>1</sup>&emsp;
    <a href='' target='_blank'>Hui Li</a><sup>1</sup>&emsp;
    <a href='' target='_blank'>Weihao Yuan</a><sup>2</sup>&emsp;
</div>
<div align='center'>
    <a href='' target='_blank'>Hao Zhu</a><sup>3</sup>&emsp;
    <a href='' target='_blank'>Zilong Dong</a><sup>2</sup>&emsp;
    <a href='https://sites.google.com/site/zhusiyucs/home' target='_blank'>Siyu Zhu</a><sup>1✉️</sup>&emsp;
</div>

<div align='center'>
    <sup>1</sup>Fudan University&emsp; 
    <sup>2</sup>Alibaba Group&emsp;
    <sup>3</sup>Nanjing University&emsp;
</div>

<div align='Center'>
<i><strong><a href='https://iccv.thecvf.com/Conferences/2025' target='_blank'>ICCV 2025 Highlight</a></strong></i>
</div>

<br>
<div align='center'>
    <a href='https://github.com/fudan-generative-vision/DicFace'><img src='https://img.shields.io/github/stars/fudan-generative-vision/DicFace'></a>
    <!-- <a href='https://github.com/fudan-generative-vision/DicFace/#/'><img src='https://img.shields.io/badge/Project-HomePage-Green'></a> -->
    <a href='https://arxiv.org/abs/2506.13355'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
    <!-- <a href=''><img src='https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Model-yellow'></a> -->
    <!-- <a href='assets/wechat.jpeg'><img src='https://badges.aleen42.com/src/wechat.svg'></a> -->
</div>

<br>

<table align="center" border="0" style="width: 100%; margin-top: 80px;">
  <tr>
    <td style="text-align: center;">
      <video src="https://github.com/user-attachments/assets/274ecc2b-3d89-4d31-bb0a-a5f3611fae8a" 
             muted autoplay loop style="display: block; margin: 0 auto;"></video>
    </td>
  </tr>
</table>

## 🖼️ Showcase

### Blind Face Restoration
<table align="center" width="100%" border="0" cellpadding="10">
  <tr>
    <td style="text-align: center;">
      <video src="https://github.com/user-attachments/assets/eb61d793-b860-476e-bae5-f6fcade1e11f" muted autoplay loop width="480"></video>
    </td>
    <td style="text-align: center;">
      <video src="https://github.com/user-attachments/assets/eb9be43a-8fb9-4fbd-ac92-a686ab0c188b" muted autoplay loop width="480"></video>
    </td>
  </tr>
</table>


### Face Inpainting
<table align="center" width="100%" border="0" cellpadding="10">
  <tr>
    <td style="text-align: center;">
      <video src="https://github.com/user-attachments/assets/1cd12d53-2ead-4cf3-b56c-1a6316484e93" muted autoplay loop width="480"></video>
    </td>
    <td style="text-align: center;">
      <video src="https://github.com/user-attachments/assets/a16b7021-a401-41cb-9a39-37a788f6a001" muted autoplay loop width="480"></video>
    </td>
  </tr>
</table>

### Face Colorization
<table align="center" width="100%" border="0" cellpadding="10">
  <tr>
    <td style="text-align: center;">
      <video src="https://github.com/user-attachments/assets/cb038911-8b26-472d-8fb9-a6cdda127084" muted autoplay loop width="480"></video>
    </td>
    <td style="text-align: center;">
      <video src="https://github.com/user-attachments/assets/ffc85ef7-4987-42af-b892-79544ea29f87" muted autoplay loop width="480"></video>
    </td>
  </tr>
</table>

### 🐾 Wild Data Examples

<div align="center">

<video src="https://github.com/user-attachments/assets/90fe03dd-b0cc-446b-bb6a-169e98c875df" muted autoplay loop width="3240"></video>
<video src="https://github.com/user-attachments/assets/c165fca5-652b-4586-a928-2ba5bda6ae03" muted autoplay loop width="3240"></video>
<br>
<video src="https://github.com/user-attachments/assets/f911165d-2259-4378-828c-a4468e5fa4dc" muted autoplay loop width="3240"></video>
<br>
<table align="center" width="100%" border="0" cellpadding="10">
  <tr>
    <td style="text-align: center;">
      <video src="https://github.com/user-attachments/assets/34eea191-f972-4b6f-9529-cc39b9831875" muted autoplay loop width="480"></video>
    </td>
    <td style="text-align: center;">
      <video src="https://github.com/user-attachments/assets/b7f0466b-321d-42b5-ae70-65b4a7347698" muted autoplay loop width="480"></video>
    </td>
  </tr>
</table>
</div>

## 📰 News

- **`2025/07/25`**: 🎉🎉🎉 Our paper has been accepted to [ICCV 2025](https://iccv.thecvf.com/Conferences/2025) and selected as a highlight.
- **`2025/06/26`**: 🎉🎉🎉 Our paper has been accepted to [ICCV 2025](https://iccv.thecvf.com/Conferences/2025).
- **`2025/06/25`**: Released our test data on the Hugging Face [repo](https://huggingface.co/datasets/fudan-generative-ai/DicFace-test_dataset).
- **`2025/06/23`**: Released our pretrained model on the Hugging Face [repo](https://huggingface.co/fudan-generative-ai/DicFace).
- **`2025/06/17`**: Paper posted on arXiv: [paper](https://arxiv.org/abs/2506.13355).
- **`2025/06/16`**: 🎉🎉🎉 Released the inference scripts.



## 📅️ Roadmap

| Status | Milestone                                                                                              |    ETA     |
| :----: | :----------------------------------------------------------------------------------------------------- | :--------: |
|   ✅   | **[Inference Code release](https://github.com/fudan-generative-vision/DicFace)**                       |  2025-6-16 |
|   ✅   | **[Model Weight release (Baidu Netdisk)](https://pan.baidu.com/s/1VTNbdtZDvgY0163a1T8ITw?pwd=dicf)**   |  2025-6-16 |
|   ✅   | **[Paper submitted on arXiv](https://arxiv.org/abs/2506.13355)**                                        |  2025-6-17 |
|   ✅   | **[Test data release](https://huggingface.co/datasets/fudan-generative-ai/DicFace-test_dataset)**       |  2025-6-25 |
|   ✅   | **Training Code release**                                                                               |  2025-6-26 |



## ⚙️ Installation

- System requirements: PyTorch >= 2.4.1, Python == 3.10
- Tested on: A800 GPUs, Python 3.10, PyTorch 2.4.1, CUDA 12.1

Download the code:

```bash
  git clone https://github.com/fudan-generative-vision/DicFace
  cd DicFace
```

Create conda environment:

```bash
  conda create -n DicFace python=3.10
  conda activate DicFace
```

Install PyTorch:

```bash
  conda install pytorch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 pytorch-cuda=12.1 -c pytorch -c nvidia
```

Install the remaining dependencies (the `pip` requirements, the local `basicsr` package, and `dlib`):

```bash
  pip install -r requirements.txt
  python basicsr/setup.py develop
  conda install -c conda-forge dlib
```
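The installation steps pin PyTorch >= 2.4.1. As a quick post-install sanity check, the stdlib-only sketch below (not part of this repository) reads the installed `torch` version via `importlib.metadata` and compares numeric release tuples. Comparing raw strings would be wrong, because `'2.10.0' < '2.4.1'` as strings. Pre-release tags such as `rc1` are simply dropped, which is a simplification.

```python
import re
from importlib import metadata


def parse_version(text):
    """Turn '2.4.1+cu121' into (2, 4, 1); the local '+cu121' suffix is ignored.

    Pre-release tags (e.g. '0rc1') are reduced to their leading digits,
    a simplification that is fine for a quick check.
    """
    release = text.split("+", 1)[0]
    parts = []
    for piece in release.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum="2.4.1"):
    """Compare release tuples, padding the shorter one with zeros."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


if __name__ == "__main__":
    try:
        torch_version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        print("torch is not installed in this environment")
    else:
        status = "OK" if meets_minimum(torch_version) else "older than 2.4.1"
        print(f"torch {torch_version}: {status}")
```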

### 📥 Download Pretrained Models

The pretrained weights have been uploaded to Baidu Netdisk; please download them from this [link](https://pan.baidu.com/s/1VTNbdtZDvgY0163a1T8ITw?pwd=dicf).

All pretrained models required for inference are also available from our Hugging Face [repo](https://huggingface.co/fudan-generative-ai/DicFace_model).

**File Structure of Pretrained Models**

The downloaded `.ckpts` directory contains the following pretrained models:

```
.ckpts
|-- CodeFormer                  # CodeFormer-related models
|   |-- bfr_100k.pth            # Blind Face Restoration model 
|   |-- color_100k.pth          # Color Restoration model 
|   |-- codeformer.pth          # CodeFormer model
|   |-- vqgan_discriminator.pth # VQGAN discriminator model
|   `-- inpainting_100k.pth     # Image Inpainting model
|-- dlib                        # dlib face-related models
|   |-- mmod_human_face_detector.dat  # Human face detector
|   `-- shape_predictor_5_face_landmarks.dat  # 5-point face landmark predictor
|-- facelib                     # Face processing library models
|   |-- detection_Resnet50_Final.pth  # ResNet50 face detector 
|   |-- detection_mobilenet0.25_Final.pth  # MobileNet0.25 face detector 
|   |-- parsing_parsenet.pth    # Face parsing model
|   |-- yolov5l-face.pth        # YOLOv5l face detection model
|   `-- yolov5n-face.pth        # YOLOv5n face detection model
|-- realesrgan                  # Real-ESRGAN super-resolution model
|   `-- RealESRGAN_x2plus.pth   # 2x super-resolution enhancement model
`-- vgg                         # VGG feature extraction model
    `-- vgg.pth                 # VGG network pre-trained weights
```
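To confirm that the download matches the layout above, a small sketch like the following can list missing files. It assumes the directory is named `.ckpts` and sits in the current working directory (pass a different `root` otherwise); the file list is copied from the tree above, and the helper itself is hypothetical, not shipped with DicFace.

```python
from pathlib import Path

# Relative paths copied from the directory tree above; trim to what your task needs.
EXPECTED_FILES = [
    "CodeFormer/bfr_100k.pth",
    "CodeFormer/color_100k.pth",
    "CodeFormer/codeformer.pth",
    "CodeFormer/vqgan_discriminator.pth",
    "CodeFormer/inpainting_100k.pth",
    "dlib/mmod_human_face_detector.dat",
    "dlib/shape_predictor_5_face_landmarks.dat",
    "facelib/detection_Resnet50_Final.pth",
    "facelib/detection_mobilenet0.25_Final.pth",
    "facelib/parsing_parsenet.pth",
    "facelib/yolov5l-face.pth",
    "facelib/yolov5n-face.pth",
    "realesrgan/RealESRGAN_x2plus.pth",
    "vgg/vgg.pth",
]


def missing_checkpoints(root=".ckpts", expected=EXPECTED_FILES):
    """Return the expected relative paths that are not regular files under `root`."""
    base = Path(root)
    return [rel for rel in expected if not (base / rel).is_file()]


if __name__ == "__main__":
    missing = missing_checkpoints()
    if missing:
        print("Missing checkpoints:")
        for rel in missing:
            print("  ", rel)
    else:
        print("All expected checkpoints found.")
```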

### 🎮 Run Inference

#### For blind face restoration

```bash
python scripts/inference.py \
		-i /path/to/video \
		-o /path/to/output_folder \
		--max_length 10 \
		--save_video_fps 24 \
		--ckpt_path /bfr/bfr_weight.pth \
		--bg_upsampler realesrgan \
		--save_video 

# or, if your video frames are already aligned:
python scripts/inference.py \
		-i /path/to/video \
		-o /path/to/output_folder \
		--max_length 10 \
		--save_video_fps 24 \
		--ckpt_path /bfr/bfr_weight.pth \
		--save_video \
		--has_aligned
```

#### For colorization & inpainting tasks


**The colorization and inpainting tasks currently support only aligned faces as input. Non-aligned faces may lead to unsatisfactory results.**

```bash
# for colorization task
python scripts/inference_color_and_inpainting.py \
		-i /path/to/video_warped \
		-o /path/to/output_folder \
		--max_length 10 \
		--save_video_fps 24 \
		--ckpt_path /colorization/colorization_weight.pth \
		--bg_upsampler realesrgan \
		--save_video \
		--has_aligned

# for inpainting task
python scripts/inference_color_and_inpainting.py \
		-i /path/to/video_warped \
		-o /path/to/output_folder \
		--max_length 10 \
		--save_video_fps 24 \
		--ckpt_path /inpainting/inpainting_weight.pth \
		--bg_upsampler realesrgan \
		--save_video \
		--has_aligned
```

## Test Data  

Our test data can be accessed via the following links:  
- Baidu Netdisk: [https://pan.baidu.com/s/1zMp3fnf6LvlRT9CAoL1OUw](https://pan.baidu.com/s/1zMp3fnf6LvlRT9CAoL1OUw) (Password: `drhh`)  
- Hugging Face Dataset: [https://huggingface.co/datasets/fudan-generative-ai/DicFace-test_dataset](https://huggingface.co/datasets/fudan-generative-ai/DicFace-test_dataset)  


### Directory Structure  
The downloaded `test_data_set` directory contains the following folders:  
```
./test_data
├── LR_Blind                  # Blind face restoration test image folders
│   ├── Clip+_HebIzK_LP4+P2+C1+F16589-16715
│   ├── ...                   # Additional test image folders
│   └── Clip+y5OFsRIRkwc+P0+C0+F9797-9938
│
├── TEST_DATA                 # Ground-truth (GT) image folders
│   ├── Clip+_HebIzK_LP4+P2+C1+F16589-16715
│   ├── ...
│   └── Clip+y5OFsRIRkwc+P0+C0+F9797-9938
│
├── vfhq_test_color_input     # Colorization test image folders
│   ├── Clip+_HebIzK_LP4+P2+C1+F16589-16715
│   ├── ...
│   └── Clip+y5OFsRIRkwc+P0+C0+F9797-9938
│
├── vfhq_test_inpaint_input_512  # Inpainting test image folders (512x512)
│   ├── Clip+_HebIzK_LP4+P2+C1+F16589-16715
│   ├── ...
│   └── Clip+y5OFsRIRkwc+P0+C0+F9797-9938
│
└── vfhq_test_landmarks       # Facial landmark files for warping operations
```
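The tree suggests that each task's input folder and `TEST_DATA` use the same clip folder names (e.g. `Clip+_HebIzK_LP4+P2+C1+F16589-16715`). Assuming that holds, the hypothetical helper below pairs input clips with their ground truth by name before computing metrics, and reports any clip that appears on only one side. The default folder names come from the tree; pass e.g. `input_dir="vfhq_test_color_input"` or a warped folder for other tasks.

```python
from pathlib import Path


def pair_clips(root, input_dir="LR_Blind", gt_dir="TEST_DATA"):
    """Match input clip folders to ground-truth clip folders by folder name.

    Returns (pairs, unmatched): pairs is a sorted list of
    (clip_name, input_path, gt_path); unmatched lists names found on one side only.
    """
    root = Path(root)
    inputs = {p.name: p for p in (root / input_dir).iterdir() if p.is_dir()}
    gts = {p.name: p for p in (root / gt_dir).iterdir() if p.is_dir()}
    shared = sorted(inputs.keys() & gts.keys())  # dict views support set operations
    pairs = [(name, inputs[name], gts[name]) for name in shared]
    unmatched = sorted(inputs.keys() ^ gts.keys())
    return pairs, unmatched
```

For example, `pairs, unmatched = pair_clips("./test_data")` would pair `LR_Blind/<clip>` with `TEST_DATA/<clip>`; a non-empty `unmatched` list is a hint that a folder was renamed or only partially downloaded.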


### Usage  
To process the test data, use the `warp_images.py` script:  
```shell
python scripts/warp_images.py \
    -i input_test_data_folder \
    -o vfhq_test_inpaint_input_512_warped \
    -l /path/to/test_data_folder/vfhq_test_landmarks
```  
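`warp_images.py` presumably aligns each frame using the landmark files in `vfhq_test_landmarks`; the repository's own alignment code lives in that script and in `facelib`, and we have not reproduced it here. For intuition, this kind of face alignment typically estimates a similarity transform (rotation, uniform scale, translation) from detected landmarks to a fixed template. The sketch below is an independent, stdlib-only least-squares solver that treats 2D points as complex numbers, so the transform is `w = s*z + t`; the template coordinates are up to you and are not taken from the repository.

```python
def estimate_similarity(src, dst):
    """Least-squares 2D similarity transform mapping src points onto dst points.

    src, dst: equal-length lists of (x, y) pairs, e.g. 5 facial landmarks
    and the corresponding template points.
    Returns a 2x3 affine matrix [[a, -b, tx], [b, a, ty]].
    """
    if len(src) != len(dst) or len(src) < 2:
        raise ValueError("need at least two matching point pairs")
    zs = [complex(x, y) for x, y in src]
    ws = [complex(x, y) for x, y in dst]
    zc = sum(zs) / len(zs)  # source centroid
    wc = sum(ws) / len(ws)  # destination centroid
    num = sum((w - wc) * (z - zc).conjugate() for z, w in zip(zs, ws))
    den = sum(abs(z - zc) ** 2 for z in zs)
    if den == 0:
        raise ValueError("source points are all identical")
    s = num / den      # s = scale * exp(i * angle)
    t = wc - s * zc    # translation that aligns the centroids
    return [[s.real, -s.imag, t.real], [s.imag, s.real, t.imag]]


def apply_affine(m, pts):
    """Apply a 2x3 affine matrix to a list of (x, y) points."""
    return [
        (m[0][0] * x + m[0][1] * y + m[0][2], m[1][0] * x + m[1][1] * y + m[1][2])
        for x, y in pts
    ]
```

With the resulting matrix `m`, a typical pipeline would then warp the frame with OpenCV, e.g. `cv2.warpAffine(frame, numpy.array(m), (512, 512))`; this is a common pattern, not a description of what `warp_images.py` does internally.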

After warping the test data, you can use the inference scripts to generate results for the test dataset.


## Training

### Training Data
We use the VFHQ dataset for both training and testing; the test data is taken from VFHQ-Test. For more details, please refer to the official project page: [VFHQ](https://liangbinxie.github.io/projects/vfhq/).

### Prerequisites for Training
Before initiating the training process, ensure that you have completed the following steps:

1. **Image Size Requirement**:
   - All input images must be resized to 512 x 512 pixels.

2. **Download Necessary Files**:
   - Obtain the metadata files and facial landmark information from our Hugging Face repository (TBD: not yet available).

3. **Configure YAML Files**:
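   - Edit the configuration file for your task under `options/` (the repository ships `clip5_bs2_512_align_nofix_multiscale.yaml`, `clip5_bs2_512_align_nofix_multiscale_color.yaml`, and `clip5_bs2_512_align_nofix_multiscale_inpaint.yaml`) to specify your training parameters and dataset paths.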

### Initiate Training
Once the prerequisites are met, start the training process by executing the following command:
```bash
bash train.sh
```

This script will initiate the training procedure using the settings defined in your YAML configuration file.


## 🤗 Acknowledgements

This project is open-sourced under the NTU S-Lab License 1.0; redistribution and use must follow this license. The code framework is largely adapted from [CodeFormer](https://github.com/sczhou/CodeFormer). Please refer to the original repo for further usage details and documentation.

## 📝 Citation

If you find our work useful for your research, please consider citing the paper:

```
@misc{chen2025dicfacedirichletconstrainedvariationalcodebook,
      title={DicFace: Dirichlet-Constrained Variational Codebook Learning for Temporally Coherent Video Face Restoration}, 
      author={Yan Chen and Hanlin Shang and Ce Liu and Yuxuan Chen and Hui Li and Weihao Yuan and Hao Zhu and Zilong Dong and Siyu Zhu},
      year={2025},
      eprint={2506.13355},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2506.13355}, 
}

```



================================================
FILE: basicsr/VERSION
================================================
1.3.2


================================================
FILE: basicsr/__init__.py
================================================
# https://github.com/xinntao/BasicSR
# flake8: noqa
from .archs import *
from .data import *
from .losses import *
from .metrics import *
from .models import *
from .ops import *
from .train import *
from .utils import *
from .version import __gitsha__, __version__


================================================
FILE: basicsr/archs/__init__.py
================================================
import importlib
from copy import deepcopy
from os import path as osp

from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import ARCH_REGISTRY

__all__ = ['build_network']

# automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]


def build_network(opt):
    opt = deepcopy(opt)
    network_type = opt.pop('type')
    net = ARCH_REGISTRY.get(network_type)(**opt)
    logger = get_root_logger()
    logger.info(f'Network [{net.__class__.__name__}] is created.')
    return net
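`build_network` works because every `*_arch.py` module is imported when the package loads, which runs its `@ARCH_REGISTRY.register()` decorators and records each class under its name. The self-contained sketch below illustrates that pattern; `Registry` and `TinyNet` are hypothetical, simplified stand-ins, not the actual `basicsr.utils.registry` implementation.

```python
from copy import deepcopy


class Registry:
    """Hypothetical minimal registry mapping class names to classes."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        """Return a class decorator that records the class under its own name."""
        def decorator(cls):
            key = cls.__name__
            if key in self._obj_map:
                raise KeyError(f"{key} is already registered in {self._name}")
            self._obj_map[key] = cls
            return cls
        return decorator

    def get(self, key):
        if key not in self._obj_map:
            raise KeyError(f"No object named {key!r} in registry {self._name!r}")
        return self._obj_map[key]


ARCH_REGISTRY = Registry("arch")


@ARCH_REGISTRY.register()
class TinyNet:
    def __init__(self, num_feat=64):
        self.num_feat = num_feat


def build_network(opt):
    opt = deepcopy(opt)             # leave the caller's config dict untouched
    network_type = opt.pop("type")  # 'type' selects the registered class...
    return ARCH_REGISTRY.get(network_type)(**opt)  # ...the rest become kwargs
```

The `deepcopy` matters because the same option dict is often reused (e.g. logged or saved later); popping `type` from the original would silently corrupt it.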


================================================
FILE: basicsr/archs/arcface_arch.py
================================================
import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY


def conv3x3(inplanes, outplanes, stride=1):
    """A simple wrapper for 3x3 convolution with padding.

    Args:
        inplanes (int): Channel number of inputs.
        outplanes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
    """
    return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    """Basic residual block used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class IRBlock(nn.Module):
    """Improved residual block (IR Block) used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
        use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super(IRBlock, self).__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if self.use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        residual = x
        out = self.bn0(x)
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.prelu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        if self.use_se:
            out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.prelu(out)

        return out


class Bottleneck(nn.Module):
    """Bottleneck block used in the ResNetArcFace architecture.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): The downsample module. Default: None.
    """
    expansion = 4  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SEBlock(nn.Module):
    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.

    Args:
        channel (int): Channel number of inputs.
        reduction (int): Channel reduction ratio. Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1 without spatial information
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
            nn.Sigmoid())

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.

    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.

    Args:
        block (str): Block used in the ArcFace architecture.
        layers (tuple(int)): Block numbers in each layer.
        use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
    """

    def __init__(self, block, layers, use_se=True):
        if block == 'IRBlock':
            block = IRBlock
        self.inplanes = 64
        self.use_se = use_se
        super(ResNetArcFace, self).__init__()

        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc5 = nn.Linear(512 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
        self.inplanes = planes
        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn4(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc5(x)
        x = self.bn5(x)

        return x
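`fc5` is hard-coded to `512 * 8 * 8` inputs, which fixes the spatial resolution `ResNetArcFace` accepts. Tracing one spatial dimension through `forward` (conv1 and layer1 keep the size; maxpool and the strided first blocks of layer2 to layer4 each halve it, for an overall factor of 16) shows that a 128x128 input, single-channel since `conv1` takes 1 channel, yields the 8x8 map `fc5` expects. This is our arithmetic from the layer definitions above, not a documented constraint; `conv_out` and `arcface_feature_size` are illustrative helpers.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a conv/pool layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def arcface_feature_size(input_size):
    """Trace one spatial dimension through ResNetArcFace.forward."""
    s = conv_out(input_size, kernel=3, padding=1)  # conv1: 3x3, stride 1, padding 1
    s = conv_out(s, kernel=2, stride=2)            # maxpool: 2x2, stride 2
    for stride in (1, 2, 2, 2):                    # layer1 .. layer4
        # only the first block of each layer is strided (3x3 conv, padding 1);
        # the remaining blocks use stride 1 and keep the size unchanged
        s = conv_out(s, kernel=3, stride=stride, padding=1)
    return s
```

For instance, `arcface_feature_size(112)` gives 7, so a 112x112 input would produce a `512 * 7 * 7` vector and fail at `fc5`.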

================================================
FILE: basicsr/archs/arch_util.py
================================================
import collections.abc
import math
import torch
import torchvision
import warnings
from distutils.version import LooseVersion
from itertools import repeat
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm

from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
from basicsr.utils import get_root_logger


@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.Module): nn.Module class for the basic block.
        num_basic_block (int): number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)
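Each stage of `Upsample` expands the channel count by r² with a 3x3 conv, and `nn.PixelShuffle(r)` then rearranges a (C·r², H, W) tensor into (C, r·H, r·W). The feature width therefore returns to `num_feat` after every stage, and a 2^n scale is built from n stages with r = 2. The hypothetical helper below mirrors that branching in plain Python (no torch) and lists the stages; unlike the original, it rejects non-positive scales explicitly, whereas `Upsample(0, ...)` would only fail through a `math.log` domain error.

```python
import math


def upsample_plan(scale, num_feat):
    """List the (conv_out_channels, pixel_shuffle_factor) stages Upsample would build."""
    if scale > 0 and (scale & (scale - 1)) == 0:  # power of two: 1, 2, 4, 8, ...
        return [(4 * num_feat, 2)] * int(math.log2(scale))
    if scale == 3:
        return [(9 * num_feat, 3)]
    raise ValueError(f"scale {scale} is not supported. Supported scales: 2^n and 3.")
```

So `upsample_plan(4, 64)` describes two conv(64 → 256) + PixelShuffle(2) stages, while a scale such as 6 is rejected even though it is a multiple of both 2 and 3.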


def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), holding per-pixel
            (x, y) displacements in pixel units (not normalized).
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before PyTorch 1.3, the default value was
            align_corners=True; since PyTorch 1.3, the default is
            align_corners=False. Here, True is used as the default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False

    vgrid = grid + flow
    # scale grid to [-1,1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)

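    # TODO: the normalization above, 2 * v / (size - 1) - 1, matches align_corners=True;
    # align_corners=False would require (2 * v + 1) / size - 1 instead.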
    return output
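

# Illustrative sketch (hypothetical helper, not used by the model): a
# self-contained check of the convention that `flow_warp` relies on. The flow is
# in pixel units. It is added to a pixel-coordinate grid and then mapped to
# [-1, 1] with 2 * v / (size - 1) - 1, which matches align_corners=True. Under
# this convention a zero flow reproduces the input, and a flow of +1 in x
# samples each pixel's right-hand neighbour. The grid is built with expand
# instead of meshgrid to avoid version-dependent `indexing` behaviour.
import torch
import torch.nn.functional as F


def _example_warp(x, flow):
    """Warp x (n, c, h, w) by a pixel-unit flow (n, h, w, 2), assuming align_corners=True."""
    _, _, h, w = x.shape
    grid_x = torch.arange(w, dtype=x.dtype, device=x.device).view(1, w).expand(h, w)
    grid_y = torch.arange(h, dtype=x.dtype, device=x.device).view(h, 1).expand(h, w)
    vgrid = torch.stack((grid_x, grid_y), dim=2) + flow  # (n, h, w, 2), (x, y) order
    vx = 2.0 * vgrid[..., 0] / max(w - 1, 1) - 1.0
    vy = 2.0 * vgrid[..., 1] / max(h - 1, 1) - 1.0
    return F.grid_sample(x, torch.stack((vx, vy), dim=3), mode='bilinear',
                         padding_mode='zeros', align_corners=True)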


def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow according to ratio or shape.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether to align corners. Default: False.

    Returns:
        Tensor: Resized flow.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
    return resized_flow
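

# Illustrative sketch (hypothetical helper, not used by the model): why
# `resize_flow` rescales the flow values as well as the resolution of the flow
# field. A 1-pixel displacement at the original resolution becomes a
# displacement of `ratio` pixels after resizing. So the x component (channel 0)
# is multiplied by ratio_w and the y component (channel 1) by ratio_h before
# interpolation. This trimmed-down version only supports the 'shape' mode.
import torch
import torch.nn.functional as F


def _example_resize_flow(flow, out_h, out_w):
    """Resize a (n, 2, h, w) pixel-unit flow to (out_h, out_w) and rescale its values."""
    _, _, h, w = flow.shape
    scaled = flow.clone()
    scaled[:, 0] *= out_w / w  # horizontal displacements
    scaled[:, 1] *= out_h / h  # vertical displacements
    return F.interpolate(scaled, size=(out_h, out_w), mode='bilinear', align_corners=False)


# Usage: a constant 1-pixel flow becomes a constant 2-pixel flow when the field
# is doubled in size, e.g. _example_resize_flow(torch.ones(1, 2, 4, 4), 8, 8).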


# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
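    """Pixel unshuffle (space-to-depth), the inverse of pixel shuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw). Both hh and hw
            must be divisible by scale.
        scale (int): Downsample ratio.

    Returns:
        Tensor: The pixel-unshuffled feature with shape
            (b, c * scale**2, hh // scale, hw // scale).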
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
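

# Illustrative sketch (hypothetical helper, not used by the model):
# `pixel_unshuffle` above is a reshape/permute space-to-depth transform. On
# PyTorch >= 1.8, torch.nn.functional.pixel_unshuffle performs the same
# operation. Both place the input pixel at (h * scale + i, w * scale + j) of
# channel c into output channel c * scale**2 + i * scale + j. The copy below
# repeats the module's logic so the comparison is self-contained.
import torch
import torch.nn.functional as F


def _example_unshuffle(x, scale):
    b, c, hh, hw = x.shape
    assert hh % scale == 0 and hw % scale == 0
    h, w = hh // scale, hw // scale
    x = x.view(b, c, h, scale, w, scale)
    return x.permute(0, 1, 3, 5, 2, 4).reshape(b, c * scale**2, h, w)


# Usage: for x = torch.arange(144.).view(2, 3, 4, 6), _example_unshuffle(x, 2)
# should equal F.pixel_unshuffle(x, 2), and F.pixel_shuffle should undo it.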


class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.

    Different from the official DCNv2Pack, which generates offsets and masks
    from the preceding features, this DCNv2Pack takes a separate feature
    tensor (`feat`) to generate offsets and masks.

    Ref:
        Delving Deep into Deformable Alignment in Video Super-Resolution.
    """

    def forward(self, x, feat):
        out = self.conv_offset(feat)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)

        offset_absmean = torch.mean(torch.abs(offset))
        if offset_absmean > 50:
            logger = get_root_logger()
            logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')

        if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                                 self.dilation, mask)
        else:
            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                         self.dilation, self.groups, self.deformable_groups)
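

# Illustrative sketch (hypothetical helper, not used by the model): the
# offset/mask layout consumed by DCNv2Pack.forward. This assumes a kh x kw
# kernel and one deformable group. Under that assumption, conv_offset emits
# 3 * kh * kw channels. They are split into two offset halves, concatenated into
# 2 * kh * kw offset channels, and kh * kw mask channels that pass through a
# sigmoid. As a sanity check, torchvision's deform_conv2d (torchvision >= 0.9
# for `mask`) with zero offsets and an all-ones mask reduces to F.conv2d.
import torch
import torch.nn.functional as F
import torchvision


def _example_split_offsets(out):
    """Split conv_offset output into (offset, mask) the way DCNv2Pack does."""
    o1, o2, mask = torch.chunk(out, 3, dim=1)
    return torch.cat((o1, o2), dim=1), torch.sigmoid(mask)


# Usage: for a 3x3 kernel, a 27-channel conv_offset output yields 18 offset
# channels and 9 mask channels.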


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        low = norm_cdf((a - mean) / std)
        up = norm_cdf((b - mean) / std)

        # Uniformly fill the tensor with values from [2l-1, 2u-1], i.e. the CDF
        # interval [low, up] mapped onto the range of erf.
        tensor.uniform_(2 * low - 1, 2 * up - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution.

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
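

# Illustrative sketch (hypothetical helper, not used by the model): the
# inverse-CDF recipe used by `_no_grad_trunc_normal_`, applied to a fresh
# tensor. Let alpha = (a - mean) / std and beta = (b - mean) / std. Uniform
# samples are drawn from [2 * Phi(alpha) - 1, 2 * Phi(beta) - 1], where Phi is
# the standard normal CDF. sqrt(2) * erfinv maps them to a standard normal
# truncated to [alpha, beta]. An affine map then restores mean and std, and a
# final clamp guards against rounding at the edges.
import math

import torch


def _example_trunc_normal(shape, mean=0.0, std=1.0, a=-2.0, b=2.0):
    def phi(v):  # standard normal CDF
        return (1.0 + math.erf(v / math.sqrt(2.0))) / 2.0

    low, up = phi((a - mean) / std), phi((b - mean) / std)
    u = torch.empty(shape).uniform_(2 * low - 1, 2 * up - 1)
    samples = torch.erfinv(u) * std * math.sqrt(2.0) + mean
    return samples.clamp(min=a, max=b)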


# From PyTorch
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
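
# Illustrative sketch (hypothetical copy, not used by the model): the behaviour
# of the `_ntuple` helpers. A scalar is repeated n times, and any iterable is
# returned unchanged, so an (h, w) pair passes through as-is. Strings are also
# iterable, so they come back unchanged rather than repeated.
import collections.abc
from itertools import repeat


def _example_ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


# Usage: _example_ntuple(2)(3) gives (3, 3); _example_ntuple(2)((1, 2)) gives
# (1, 2); _example_ntuple(2)('ab') gives 'ab', not ('ab', 'ab').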

================================================
FILE: basicsr/archs/dir_dist_codeformer_multiscale_arch.py
================================================
import math
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional, List

from basicsr.archs.vqgan_arch import *
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY

import torch.distributions as dist

from einops import rearrange

def calc_mean_std(feat, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    feat_var = feat.view(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(b, c, 1, 1)
    feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    return feat_mean, feat_std


def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization.

    Adjust the reference features to have color and illumination similar
    to those of the degraded features.

    Args:
        content_feat (Tensor): The reference feature.
        style_feat (Tensor): The degraded features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
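

# Illustrative sketch (hypothetical helper, not used by the model): what
# adaptive_instance_normalization guarantees. Each channel of the content
# feature is standardized with its own mean and std, then rescaled with the
# style feature's per-channel mean and std. The output's channel statistics
# therefore approximately match the style feature's; the eps added to the
# variance makes the match approximate rather than exact. This is a compact
# re-implementation of the two functions above.
import torch


def _example_adain(content, style, eps=1e-5):
    b, c = content.shape[:2]

    def stats(feat):
        flat = feat.reshape(b, c, -1)
        mean = flat.mean(dim=2).view(b, c, 1, 1)
        std = (flat.var(dim=2) + eps).sqrt().view(b, c, 1, 1)
        return mean, std

    c_mean, c_std = stats(content)
    s_mean, s_std = stats(style)
    return (content - c_mean) / c_std * s_std + s_mean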


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used in the "Attention Is All You Need" paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
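

# Illustrative sketch (hypothetical helper, not used by the model): the layout
# produced by PositionEmbeddingSine when no mask is given and normalize=False.
# Row and column indices (1-based, as produced by cumsum) are divided by
# temperature ** (2 * floor(k / 2) / num_pos_feats). Even feature slots take
# sin and odd slots take cos. The y features are placed in front of the x
# features, giving 2 * num_pos_feats channels per pixel.
import torch


def _example_sine_pos(b, h, w, num_pos_feats=4, temperature=10000):
    y_embed = torch.arange(1, h + 1, dtype=torch.float32).view(1, h, 1).expand(b, h, w)
    x_embed = torch.arange(1, w + 1, dtype=torch.float32).view(1, 1, w).expand(b, h, w)
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * torch.floor(dim_t / 2) / num_pos_feats)

    def encode(embed):
        pos = embed[..., None] / dim_t  # (b, h, w, num_pos_feats)
        return torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=4).flatten(3)

    return torch.cat((encode(y_embed), encode(x_embed)), dim=3).permute(0, 3, 1, 2)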


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")


class TransformerSALayer(nn.Module):
    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Implementation of Feedforward model - MLP
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(
        self,
        tgt,
        tgt_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):

        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q,
                              k,
                              value=tgt2,
                              attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # ffn
        tgt2 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout2(tgt2)
        return tgt
 
 
class TransformerSALayerTemporal(nn.Module):
    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Implementation of Feedforward model - MLP
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt,
                frame_length=10,
                batch_size=1,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):

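        # regroup (d, b * t, c) tokens into (t, b * d, c) so that self-attention
        # runs across the frames of each spatial location
        tgt = rearrange(tgt, "d (b t) c -> t (b d) c", t=frame_length)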

        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q,
                              k,
                              value=tgt2,
                              attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # ffn
        tgt2 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout2(tgt2)
        # restore the (d, b * t, c) layout used by the spatial layers
        tgt = rearrange(tgt, "t (b d) c -> d (b t) c", b=batch_size)

        return tgt
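

# Illustrative sketch (hypothetical helpers, not used by the model): the token
# regrouping performed by TransformerSALayerTemporal, written with plain
# view/permute instead of einops. The spatial layers see a (d, b * t, c)
# sequence, where d = h * w tokens per frame. The temporal layer rearranges it
# to (t, b * d, c) so that self-attention runs across the t frames of each
# spatial location, then restores the original layout.
import torch


def _example_to_temporal(x, t):
    """Equivalent of 'd (b t) c -> t (b d) c'."""
    d, bt, c = x.shape
    b = bt // t
    return x.view(d, b, t, c).permute(2, 1, 0, 3).reshape(t, b * d, c)


def _example_to_spatial(x, b):
    """Equivalent of 't (b d) c -> d (b t) c'."""
    t, bd, c = x.shape
    d = bd // b
    return x.view(t, b, d, c).permute(2, 1, 0, 3).reshape(d, b * t, c)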


class Fuse_sft_block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2*in_ch, out_ch)

        self.scale = nn.Sequential(
                    nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                    nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        self.shift = nn.Sequential(
                    nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                    nn.LeakyReLU(0.2, True),
                    nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

    def forward(self, enc_feat, dec_feat, w=1):
        enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        scale = self.scale(enc_feat)
        shift = self.shift(enc_feat)
        residual = w * (dec_feat * scale + shift)
        out = dec_feat + residual
        return out

class ExpModule(nn.Module):
    def forward(self, x):
        return torch.exp(x)


class MultiScaleFuse(nn.Module):
    def __init__(self):
        super(MultiScaleFuse, self).__init__()
        self.s64_conv = nn.Conv2d(in_channels=256*16, out_channels=256, kernel_size=1)
        self.s32_conv = nn.Conv2d(in_channels=256*4, out_channels=256, kernel_size=1)
        self.s16_conv = nn.Conv2d(in_channels=256*1, out_channels=256, kernel_size=1)
        self.out = nn.Conv2d(in_channels=256*3, out_channels=256, kernel_size=3, stride=1, padding=1)

    def forward(self, s64, s32, s16):

        # fold 4x4 (64x64 input) and 2x2 (32x32 input) spatial blocks into channels
        # so that all three scales share the 16x16 resolution
        feat_64 = rearrange(s64, "bt c (h h1) (w w1) -> bt (c h1 w1) h w", h1=4, w1=4)
        feat_64 = self.s64_conv(feat_64)
        feat_32 = rearrange(s32, "bt c (h h1) (w w1) -> bt (c h1 w1) h w", h1=2, w1=2)
        feat_32 = self.s32_conv(feat_32)
        feat_16 = self.s16_conv(s16)

        out = self.out(torch.cat([feat_64, feat_32, feat_16], dim=1))
        return out
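

# Illustrative sketch (hypothetical class, not used by the model): shape
# bookkeeping in MultiScaleFuse. The 64x64 and 32x32 encoder features are folded
# into 16x16 maps by space-to-depth. Each is projected back to a common channel
# count by a 1x1 convolution, concatenated with the 16x16 feature, and fused by
# a 3x3 convolution. The einops pattern "bt c (h h1) (w w1) -> bt (c h1 w1) h w"
# is replaced by F.pixel_unshuffle, which should give the same channel order.
# The channel count is reduced from 256 to 8 so the sketch runs quickly.
import torch
import torch.nn.functional as F
from torch import nn


class _ExampleMultiScaleFuse(nn.Module):
    def __init__(self, ch=8):
        super().__init__()
        self.s64_conv = nn.Conv2d(ch * 16, ch, kernel_size=1)
        self.s32_conv = nn.Conv2d(ch * 4, ch, kernel_size=1)
        self.s16_conv = nn.Conv2d(ch, ch, kernel_size=1)
        self.out = nn.Conv2d(ch * 3, ch, kernel_size=3, padding=1)

    def forward(self, s64, s32, s16):
        feat_64 = self.s64_conv(F.pixel_unshuffle(s64, 4))  # (bt, 16*ch, 16, 16) -> (bt, ch, 16, 16)
        feat_32 = self.s32_conv(F.pixel_unshuffle(s32, 2))  # (bt, 4*ch, 16, 16) -> (bt, ch, 16, 16)
        feat_16 = self.s16_conv(s16)
        return self.out(torch.cat([feat_64, feat_32, feat_16], dim=1))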

@ARCH_REGISTRY.register()
class TemporalCodeFormerDirDistMultiScale(VQAutoEncoder):
    def __init__(self,
                 dim_embed=512,
                 n_head=8,
                 n_layers=9, 
                 codebook_size=1024,
                 latent_size=256,
                 connect_list=['32', '64', '128', '256'],
                 fix_modules=['quantize','generator'],
                 vqgan_path=None,
                 frame_length=10,
                 new_codebook_size=None):
        super(TemporalCodeFormerDirDistMultiScale, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], codebook_size)

        if vqgan_path is not None:
            self.load_state_dict(
                torch.load(vqgan_path, map_location='cpu')['params_ema'])

        self.frame_length = frame_length

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embed = dim_embed
        self.dim_mlp = dim_embed * 2

        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embed))
        self.position_emb_temporal = nn.Parameter(torch.zeros(self.frame_length, self.dim_embed))
        self.feat_emb = nn.Linear(256, self.dim_embed)

        self.codebook_size = codebook_size
        self.new_codebook_size = None
        if new_codebook_size is not None:
            self.new_codebook_size = new_codebook_size
            self.codebook_size += new_codebook_size
            self.new_codebook = nn.Parameter(torch.normal(mean=0, std=0.75, size=(new_codebook_size, 256)))
            self.new_codebook.requires_grad = True

        self.multiscale = MultiScaleFuse()

        # spatial transformer layers (attention over the h * w tokens of each frame)
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embed,
                                                            nhead=n_head,
                                                            dim_mlp=self.dim_mlp,
                                                            dropout=0.1)
                                    for _ in range(self.n_layers)])
        # temporal transformer layers (attention over the frames at each spatial location)
        self.dir_dist_layers = nn.Sequential(*[TransformerSALayerTemporal(embed_dim=dim_embed,
                                                                          nhead=n_head,
                                                                          dim_mlp=self.dim_mlp,
                                                                          dropout=0.1)
                                    for _ in range(self.n_layers)])

        # head predicting (pre-softplus) Dirichlet concentration parameters over the codebook
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embed),
            nn.Linear(dim_embed, self.codebook_size, bias=False),
        )

        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }


        self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
        self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}

        # fuse_convs_dict
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

        self.softplus_layer = nn.Softplus()
        self.position_emb.requires_grad = False
        print("Module: position_emb_spatial Frozen!")

        if fix_modules is not None:
            print(fix_modules, "frozen!")
            for module in fix_modules:
                for param_name, param in getattr(self, module).named_parameters():
                    if "conv3d" in param_name:
                        param.requires_grad = True
                    else:
                        # print(f"Module: {module}, Parameter name: {param_name} Frozen!")
                        param.requires_grad = False

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        # ################### Encoder #####################
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x) 
            if i in out_list:
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = self.multiscale(enc_feat_dict['64'], enc_feat_dict['32'], x)

        bt, c, h, width = lq_feat.shape
        b = bt // self.frame_length
        t = self.frame_length
        # ################# Spatial & Temporal Transformers ###################
        spatial_pos_emb = self.position_emb.unsqueeze(1).repeat(1, bt, 1)
        temporal_pos_emb = self.position_emb_temporal.unsqueeze(1).repeat(1, b*h*width, 1)
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
        query_emb = feat_emb

        for layer_space, layer_temporal in zip(self.ft_layers, self.dir_dist_layers):
            query_emb = layer_space(query_emb, query_pos=spatial_pos_emb)
            query_emb = layer_temporal(query_emb, query_pos=temporal_pos_emb, frame_length=t, batch_size=b)

        alpha = self.idx_pred_layer(query_emb)
        alpha = alpha.permute(1, 0, 2)
        alpha = self.softplus_layer(alpha) + 1e-2

        dirichlet_dist = dist.Dirichlet(alpha)
        parameters = dirichlet_dist.rsample()

        parameters_reshaped = parameters.reshape(-1, self.codebook_size)

        if self.new_codebook_size is not None:
            quant_feat = torch.matmul(parameters_reshaped[:, :-self.new_codebook_size], self.quantize.embedding.weight) + \
                 torch.matmul(parameters_reshaped[:, -self.new_codebook_size:], self.new_codebook)
        else:
            quant_feat = torch.matmul(parameters_reshaped, self.quantize.embedding.weight) 

        quant_feat = rearrange(quant_feat, "(b t h w) c -> (b t) c h w", b=b, t=t, h=h, w=width)


        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x) 
            if i in fuse_list:
                f_size = str(x.shape[-1])
                if w > 0:
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        return out, lq_feat, alpha + 1e-6
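

# Illustrative sketch (hypothetical helper, not used by the model): the
# Dirichlet-weighted codebook lookup at the end of
# TemporalCodeFormerDirDistMultiScale.forward. Instead of picking one code index
# per token, the model predicts concentration parameters alpha > 0 (softplus
# plus 1e-2). It draws a reparameterized sample from Dirichlet(alpha) and uses
# it as convex weights over the codebook entries, so gradients flow back into
# alpha. Sizes in the usage note are toy values.
import torch
import torch.distributions as dist
import torch.nn.functional as F


def _example_dirichlet_quantize(logits, codebook):
    """logits: (tokens, K); codebook: (K, C). Returns (tokens, C) features and the weights."""
    alpha = F.softplus(logits) + 1e-2          # strictly positive concentrations
    weights = dist.Dirichlet(alpha).rsample()  # (tokens, K), each row sums to 1
    return weights @ codebook, weights


# Usage: with logits of shape (6, 10) and a (10, 4) codebook, the result is a
# (6, 4) feature that is differentiable with respect to the logits.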


================================================
FILE: basicsr/archs/rrdbnet_arch.py
================================================
import torch
from torch import nn as nn
from torch.nn import functional as F

from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import default_init_weights, make_layer, pixel_unshuffle


class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return x5 * 0.2 + x


class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x


@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
    """Networks consisting of Residual in Residual Dense Block, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN to scale x2 and scale x1.
    Note: this is one option for supporting scale 1 and scale 2 in RRDBNet.
    We first apply pixel-unshuffle (the inverse operation of pixel shuffle) to reduce the spatial size
    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Upsampling factor of the network. Supported: 1, 2 and 4. Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        num_block (int): Block number in the trunk network. Default: 23.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
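

# Illustrative sketch (hypothetical helper, not used by the model): how RRDBNet
# supports x1 and x2 output even though its tail always upsamples by 4 (two
# nearest-neighbour x2 steps). For scale=2 the input is pixel-unshuffled by 2
# (channels x4, size / 2). For scale=1 it is pixel-unshuffled by 4 (channels
# x16, size / 4). Any other scale feeds the input as-is, giving x4. This walks
# through the shapes only. The convolutions and the RRDB trunk are omitted
# because they preserve spatial size, so the channel count shown is the
# unshuffled one rather than num_out_ch.
import torch
import torch.nn.functional as F


def _example_rrdb_io(x, scale):
    if scale == 2:
        x = F.pixel_unshuffle(x, 2)
    elif scale == 1:
        x = F.pixel_unshuffle(x, 4)
    x = F.interpolate(x, scale_factor=2, mode='nearest')
    x = F.interpolate(x, scale_factor=2, mode='nearest')
    return x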

================================================
FILE: basicsr/archs/vgg_arch.py
================================================
import os
import torch
from collections import OrderedDict
from torch import nn as nn
from torchvision.models import vgg as vgg

from basicsr.utils.registry import ARCH_REGISTRY

VGG_PRETRAIN_PATH = './ckpts/vgg/vgg16-397923af.pth'
NAMES = {
    'vgg11': [
        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
        'pool5'
    ],
    'vgg13': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
    ],
    'vgg16': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
        'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
        'pool5'
    ],
    'vgg19': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
    ]
}


def insert_bn(names):
    """Insert bn layer after each conv.

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers.
    """
    names_bn = []
    for name in names:
        names_bn.append(name)
        if 'conv' in name:
            position = name.replace('conv', '')
            names_bn.append('bn' + position)
    return names_bn


@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    In this implementation, we allow users to choose whether to normalize
    the input and which type of VGG network to use. Note that the local
    pretrained checkpoint (VGG_PRETRAIN_PATH) must match the VGG type.

    Args:
        layer_name_list (list[str]): Forward function returns the corresponding
            features according to the layer_name_list.
            Example: ['relu1_1', 'relu2_1', 'relu3_1'].
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input (after the optional range_norm) must be in the range
            [0, 1]. Default: True.
        range_norm (bool): If True, map images in the range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If True, the parameters of the VGG network will be
            optimized. Default: False.
        remove_pooling (bool): If True, the max pooling operations in the VGG net
            will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        self.names = NAMES[vgg_type.replace('_bn', '')]
        if 'bn' in vgg_type:
            self.names = insert_bn(self.names)

        # only borrow layers that will be used to avoid unused params
        max_idx = 0
        for v in layer_name_list:
            idx = self.names.index(v)
            if idx > max_idx:
                max_idx = idx

        if os.path.exists(VGG_PRETRAIN_PATH):
            vgg_net = getattr(vgg, vgg_type)(pretrained=False)
            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
            vgg_net.load_state_dict(state_dict)
        else:
            vgg_net = getattr(vgg, vgg_type)(pretrained=True)
        
        features = vgg_net.features[:max_idx + 1]

        modified_net = OrderedDict()
        for k, v in zip(self.names, features):
            if 'pool' in k:
                # if remove_pooling is true, pooling operation will be removed
                if remove_pooling:
                    continue
                else:
                    # in some cases, we may want to change the default stride
                    modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
            else:
                modified_net[k] = v

        self.vgg_net = nn.Sequential(modified_net)

        if not requires_grad:
            self.vgg_net.eval()
            for param in self.parameters():
                param.requires_grad = False
        else:
            self.vgg_net.train()
            for param in self.parameters():
                param.requires_grad = True

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = {}

        for key, layer in self.vgg_net._modules.items():
            x = layer(x)
            if key in self.layer_name_list:
                output[key] = x.clone()

        return output
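Before any VGG layer runs, `forward` optionally maps a [-1, 1] input to [0, 1] and then standardizes it with the ImageNet statistics stored in the `mean`/`std` buffers. The NumPy sketch below restates only that preprocessing so the broadcasting of the (1, 3, 1, 1) statistics over an (n, c, h, w) batch is visible; `vgg_input_norm` is a hypothetical helper, not part of the repository.

```python
import numpy as np

# ImageNet statistics for inputs in [0, 1], shaped like the registered buffers
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
IMAGENET_STD = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)


def vgg_input_norm(x, range_norm=False, use_input_norm=True):
    """Preprocessing done at the top of VGGFeatureExtractor.forward.

    x: array of shape (n, 3, h, w), in [-1, 1] if range_norm else in [0, 1].
    """
    x = np.asarray(x, dtype=np.float64)
    if range_norm:
        x = (x + 1) / 2                                # [-1, 1] -> [0, 1]
    if use_input_norm:
        x = (x - IMAGENET_MEAN) / IMAGENET_STD         # broadcasts over n, h, w
    return x


# An all -1 image maps to 0 under range_norm, so channel c becomes -mean[c] / std[c]
y = vgg_input_norm(-np.ones((2, 3, 4, 4)), range_norm=True)
print(y.shape)  # (2, 3, 4, 4)
```

Because the statistics assume [0, 1] inputs, forgetting `range_norm=True` for a [-1, 1] generator output silently shifts every feature the extractor returns.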


================================================
FILE: basicsr/archs/vqgan_arch.py
================================================
'''
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py

'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY

def normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    

@torch.jit.script
def swish(x):
    return x*torch.sigmoid(x)


#  Define VQVAE classes
class VectorQuantizer(nn.Module):
    def __init__(self, codebook_size, emb_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.beta = beta  # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.emb_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
            2 * torch.matmul(z_flattened, self.embedding.weight.t())

        mean_distance = torch.mean(d)
        # find closest encodings
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        # min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
        # [0-1], higher score, higher confidence
        # min_encoding_scores = torch.exp(-min_encoding_scores/10)

        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
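        # compute the VQ loss; note that, as written, beta scales the second
        # (codebook) term, which updates the embedding, rather than the
        # commitment term that updates the encoder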
        loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
        # preserve gradients
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, {
            "perplexity": perplexity,
            "min_encodings": min_encodings,
            "min_encoding_indices": min_encoding_indices,
            "mean_distance": mean_distance
            }

    def get_codebook_feat(self, indices, shape):
        # input indices: batch*token_num -> (batch*token_num)*1
        # shape: batch, height, width, channel
        indices = indices.view(-1,1)
        min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
        min_encodings.scatter_(1, indices, 1)
        # get quantized latent vectors
        z_q = torch.matmul(min_encodings.float(), self.embedding.weight)

        if shape is not None:  # reshape back to match original input shape
            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()

        return z_q


class GumbelQuantizer(nn.Module):
    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
        super().__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)  # projects last encoder layer to quantized logits
        self.embed = nn.Embedding(codebook_size, emb_dim)

    def forward(self, z):
        hard = self.straight_through if self.training else True

        logits = self.proj(z)

        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)

        z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)

        # + kl divergence to the prior loss
        qy = F.softmax(logits, dim=1)
        diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
        min_encoding_indices = soft_one_hot.argmax(dim=1)

        return z_q, diff, {
            "min_encoding_indices": min_encoding_indices
        }


class Downsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        pad = (0, 1, 0, 1)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        x = self.conv(x)

        return x


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            # 1x1 projection so the skip connection matches the output channels
            self.conv_out = nn.Conv2d(in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            x_in = self.conv_out(x_in)

        return x + x_in


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

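        # compute attention over all h*w spatial positions
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)    # (b, hw, c)
        k = k.reshape(b, c, h*w)  # (b, c, hw)
        w_ = torch.bmm(q, k)      # (b, hw_query, hw_key)
        w_ = w_ * (int(c)**(-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)  # (b, hw_key, hw_query)
        h_ = torch.bmm(v, w_)     # (b, c, hw_query): weighted sum of values per query
        h_ = h_.reshape(b, c, h, w)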

        h_ = self.proj_out(h_)

        return x+h_


class Encoder(nn.Module):
    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        curr_res = self.resolution
        in_ch_mult = (1,)+tuple(ch_mult)

        blocks = []
        # initial convolution
        blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))

        # residual and downsampling blocks, with attention at the resolutions in attn_resolutions (e.g. 16x16)
        for i in range(self.num_resolutions):
            block_in_ch = nf * in_ch_mult[i]
            block_out_ch = nf * ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
                if curr_res in attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != self.num_resolutions - 1:
                blocks.append(Downsample(block_in_ch))
                curr_res = curr_res // 2

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        # normalise and convert to latent size
        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
            
        return x


class Generator(nn.Module):
    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
        super().__init__()
        self.nf = nf 
        self.ch_mult = ch_mult 
        self.num_resolutions = len(self.ch_mult)
        self.num_res_blocks = res_blocks
        self.resolution = img_size 
        self.attn_resolutions = attn_resolutions
        self.in_channels = emb_dim
        self.out_channels = 3
        block_in_ch = self.nf * self.ch_mult[-1]
        curr_res = self.resolution // 2 ** (self.num_resolutions-1)

        blocks = []
        # initial conv
        blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        for i in reversed(range(self.num_resolutions)):
            block_out_ch = self.nf * self.ch_mult[i]

            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch

                if curr_res in self.attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != 0:
                blocks.append(Upsample(block_in_ch))
                curr_res = curr_res * 2

        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))

        self.blocks = nn.ModuleList(blocks)
   

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
            
        return x

  
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
                 beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
        super().__init__()
        logger = get_root_logger()
        self.in_channels = 3
        self.nf = nf
        self.n_blocks = res_blocks
        self.codebook_size = codebook_size
        self.embed_dim = emb_dim
        self.ch_mult = ch_mult
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.quantizer_type = quantizer
        self.encoder = Encoder(
            self.in_channels,
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )
        if self.quantizer_type == "nearest":
            self.beta = beta #0.25
            self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
        elif self.quantizer_type == "gumbel":
            self.gumbel_num_hiddens = emb_dim
            self.straight_through = gumbel_straight_through
            self.kl_weight = gumbel_kl_weight
            self.quantize = GumbelQuantizer(
                self.codebook_size,
                self.embed_dim,
                self.gumbel_num_hiddens,
                self.straight_through,
                self.kl_weight
            )
        else:
            raise ValueError(f'Unknown quantizer type: {self.quantizer_type}')
        self.generator = Generator(
            self.nf, 
            self.embed_dim,
            self.ch_mult, 
            self.n_blocks, 
            self.resolution, 
            self.attn_resolutions
        )

        if model_path is not None:
            # load the checkpoint once, preferring the EMA weights when present
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_ema' in chkpt:
                self.load_state_dict(chkpt['params_ema'])
                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
                logger.info(f'vqgan is loaded from: {model_path} [params]')
            else:
                raise ValueError(f'Wrong params! Expected "params_ema" or "params" in {model_path}')


    def forward(self, x):
        x = self.encoder(x)
        quant, codebook_loss, quant_stats = self.quantize(x)
        x = self.generator(quant)
        return x, codebook_loss, quant_stats



# patch-based discriminator: outputs a 1-channel map with one prediction per patch
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
        super().__init__()

        layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
        ndf_mult = 1
        ndf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            ndf_mult_prev = ndf_mult
            ndf_mult = min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(ndf * ndf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        ndf_mult_prev = ndf_mult
        ndf_mult = min(2 ** n_layers, 8)

        layers += [
            nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(ndf * ndf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        layers += [
            nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)]  # output 1 channel prediction map
        self.main = nn.Sequential(*layers)

        if model_path is not None:
            # load the checkpoint once, preferring the discriminator weights
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_d' in chkpt:
                self.load_state_dict(chkpt['params_d'])
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
            else:
                raise ValueError(f'Wrong params! Expected "params_d" or "params" in {model_path}')

    def forward(self, x):
        return self.main(x)
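`VectorQuantizer.forward` replaces every spatial feature vector by its nearest codebook entry, using the expansion ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z·e to get all distances from one matrix product, and `get_codebook_feat` performs the same one-hot lookup starting from indices. The NumPy sketch below restates both for forward values only: there is no autograd, so it leaves out the straight-through step `z + (z_q - z).detach()` (whose forward value is simply `z_q`) and the loss. `quantize` and `codebook_feat` are hypothetical names.

```python
import numpy as np


def quantize(z, codebook):
    """NumPy sketch of VectorQuantizer.forward (forward values only, no autograd).

    z: (b, c, h, w) encoder output; codebook: (K, c) embedding table.
    Returns z_q shaped like z, the flat code indices (b*h*w,) and the
    perplexity of code usage.
    """
    b, c, h, w = z.shape
    z_flat = z.transpose(0, 2, 3, 1).reshape(-1, c)            # (b*h*w, c)
    # squared distances to every code: ||z||^2 + ||e||^2 - 2 z.e, shape (b*h*w, K)
    d = ((z_flat ** 2).sum(axis=1, keepdims=True)
         + (codebook ** 2).sum(axis=1)
         - 2 * z_flat @ codebook.T)
    idx = d.argmin(axis=1)
    one_hot = np.zeros((idx.size, codebook.shape[0]))
    one_hot[np.arange(idx.size), idx] = 1                       # like scatter_(1, idx, 1)
    z_q = one_hot @ codebook                                    # picks the chosen codes
    e_mean = one_hot.mean(axis=0)                               # usage frequency per code
    perplexity = np.exp(-np.sum(e_mean * np.log(e_mean + 1e-10)))
    z_q = z_q.reshape(b, h, w, c).transpose(0, 3, 1, 2)         # back to (b, c, h, w)
    return z_q, idx, perplexity


def codebook_feat(indices, codebook, shape):
    """Sketch of get_codebook_feat; shape is (b, h, w, c) as in the original."""
    # a one-hot matmul with the codebook is the same as plain row indexing
    return codebook[indices.reshape(-1)].reshape(shape).transpose(0, 3, 1, 2)


codebook = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
z_flat = np.array([[0.1, -0.1], [0.9, 0.2], [0.2, 0.8], [0.95, 1.1]])
z = z_flat.reshape(1, 2, 2, 2).transpose(0, 3, 1, 2)           # (b, c, h, w)
z_q, idx, perplexity = quantize(z, codebook)
print(idx)                    # [0 1 2 3]
print(round(perplexity, 3))   # 4.0
```

Perplexity equals the number of codes when all of them are used equally often and drops towards 1 as usage collapses onto a single code, which is why it is logged as a codebook-health statistic.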

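The spatial sizes in `Encoder`, `Generator` and `VQGANDiscriminator` follow from standard convolution arithmetic, out = floor((in + 2p - k) / s) + 1. The sketch below applies it to `Downsample` (one-sided padding of one pixel, then a stride-2 3x3 convolution), to the chain of resolution levels in `Encoder` (there is no downsampling after the last level, which is why `Generator` starts at `img_size // 2 ** (len(ch_mult) - 1)`), and to the patch map produced by the discriminator. The helper names and the example `ch_mult` are illustrative assumptions, not values read from this repository's configs.

```python
def conv_out(size, kernel, stride, padding=0):
    """Output length of a Conv2d along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def downsample_out(size):
    # Downsample pads (0, 1, 0, 1): one extra pixel on the right/bottom only,
    # followed by a 3x3 convolution with stride 2 and no padding.
    return conv_out(size + 1, kernel=3, stride=2)


def encoder_sizes(img_size, ch_mult):
    """Spatial size at each Encoder level; the last entry is the latent size."""
    sizes = [img_size]
    for _ in range(len(ch_mult) - 1):  # no Downsample after the last level
        sizes.append(downsample_out(sizes[-1]))
    return sizes


def patch_map_size(img_size, n_layers=4):
    """Side length of the 1-channel map returned by VQGANDiscriminator."""
    s = conv_out(img_size, 4, 2, 1)      # first conv, stride 2
    for _ in range(1, n_layers):         # strided convs built in the loop
        s = conv_out(s, 4, 2, 1)
    s = conv_out(s, 4, 1, 1)             # stride-1 conv followed by BatchNorm
    return conv_out(s, 4, 1, 1)          # final prediction conv


sizes = encoder_sizes(512, ch_mult=(1, 2, 2, 4, 4, 8))  # example values
print(sizes)                  # [512, 256, 128, 64, 32, 16]
print(512 // 2 ** (6 - 1))    # 16, the Generator's starting resolution
print(patch_map_size(512))    # 30
```

With `attn_resolutions=[16]` and these example values, only the last Encoder level (and the first Generator level) would receive `AttnBlock`s, since attention is added wherever the current resolution appears in that list.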
================================================
FILE: basicsr/data/__init__.py
================================================
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp

from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import DATASET_REGISTRY

__all__ = ['build_dataset', 'build_dataloader']

# automatically scan and import dataset modules for registry
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]


def build_dataset(dataset_opt):
    """Build dataset from options.

    Args:
        dataset_opt (dict): Configuration for dataset. It must contain:
            name (str): Dataset name.
            type (str): Dataset type.
    """
    dataset_opt = deepcopy(dataset_opt)
    dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
    logger = get_root_logger()
    logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
    return dataset


def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
    """Build dataloader.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train' or 'val'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()
    if phase == 'train':
        if dist:  # distributed training
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']
        else:  # non-distributed training
            multiplier = 1 if num_gpu == 0 else num_gpu
            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
        dataloader_args = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            drop_last=True)
        if sampler is None:
            dataloader_args['shuffle'] = True
        dataloader_args['worker_init_fn'] = partial(
            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
    elif phase in ['val', 'test']:  # validation
        dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    else:
        raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode == 'cpu':  # CPUPrefetcher
        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
        logger = get_root_logger()
        logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
        return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
    else:
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**dataloader_args)


def worker_init_fn(worker_id, num_workers, rank, seed):
    # Set the worker seed to num_workers * rank + worker_id + seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
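In the train phase, `build_dataloader` sizes each DataLoader differently depending on whether training is distributed: with one process per GPU (`dist=True`) every loader serves a single GPU, while a single non-distributed process multiplies the per-GPU settings by the GPU count. `worker_init_fn` then gives each worker a distinct NumPy/`random` seed. The pure-Python sketch below restates both rules; `train_loader_sizes` and `worker_seed` are hypothetical helpers.

```python
def train_loader_sizes(batch_size_per_gpu, num_worker_per_gpu, num_gpu=1, dist=False):
    """(batch_size, num_workers) of one process's train DataLoader."""
    if dist:
        # one process per GPU: the loader only feeds its own GPU
        return batch_size_per_gpu, num_worker_per_gpu
    # one process for all GPUs; num_gpu == 0 (CPU) is treated as 1
    multiplier = 1 if num_gpu == 0 else num_gpu
    return batch_size_per_gpu * multiplier, num_worker_per_gpu * multiplier


def worker_seed(worker_id, num_workers, rank, seed):
    # same formula as worker_init_fn: contiguous, non-overlapping blocks per rank
    return num_workers * rank + worker_id + seed


print(train_loader_sizes(4, 2, num_gpu=8, dist=False))  # (32, 16)
print(train_loader_sizes(4, 2, num_gpu=8, dist=True))   # (4, 2)
seeds = [worker_seed(w, 4, r, seed=10) for r in range(2) for w in range(4)]
print(seeds)  # [10, 11, 12, 13, 14, 15, 16, 17]
```

Either way the global batch per step is `batch_size_per_gpu * num_gpu`; only the number of loaders producing it changes.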


================================================
FILE: basicsr/data/color_dataset.py
================================================
import os
import random
from pathlib import Path

from PIL import Image
import cv2
import ffmpeg
import io
import av
import numpy as np
import torch
from torchvision.transforms.functional import normalize
from basicsr.data.degradations import (random_add_gaussian_noise,
                                       random_mixed_kernels)
from basicsr.data.data_util import paths_from_folder, brush_stroke_mask, brush_stroke_mask_video, random_ff_mask
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, img2tensor, imfrombytes, scandir
from basicsr.utils.registry import DATASET_REGISTRY
from facelib.utils.face_restoration_helper import FaceAligner
from torch.utils import data as data

@DATASET_REGISTRY.register()
class ColorizationDataset(data.Dataset):
    def __init__(self, opt):
        super(ColorizationDataset, self).__init__()
        self.opt = opt
        self.gt_root = Path(opt['dataroot_gt'])

        self.num_frame = opt['video_length'] # 5
        self.scale = opt['scale'] # [1, 4]
        self.need_align = opt.get('need_align', False) # False
        self.normalize = opt.get('normalize', False) # True

        self.keys = []
        with open(opt['global_meta_info_file'], 'r') as fin:
            for line in fin:
                real_clip_path = '/'.join(line.split('/')[:-1])
                clip_length = int(line.split('/')[-1])
                self.keys.extend([f'{real_clip_path}/{clip_length:08d}/{0:08d}'])

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            self.io_backend_opt['db_paths'] = [self.gt_root]
            self.io_backend_opt['client_keys'] = ['gt']

        # temporal augmentation configs
        self.interval_list = opt['interval_list'] # [1]
        self.random_reverse = opt['random_reverse']
        interval_str = ','.join(str(x) for x in opt['interval_list']) # '1'
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

        # degradations
        # blur
        self.blur_kernel_size = opt['blur_kernel_size'] # 21
        self.kernel_list = opt['kernel_list']           # ['iso', 'aniso']
        self.kernel_prob = opt['kernel_prob']           # [0.5, 0.5]  
        self.blur_x_sigma = opt['blur_x_sigma']         # [0.2, 3]
        self.blur_y_sigma = opt['blur_y_sigma']         # [0.2, 3]
        # noise
        self.noise_range = opt['noise_range']           # [0, 25] 
        # resize
        self.resize_prob = opt['resize_prob']           # [0.25, 0.25, 0.5]
        # crf
        self.crf_range = opt['crf_range']               # [10, 30]
        # codec
        self.vcodec = opt['vcodec']                     # ['libx264']
        self.vcodec_prob = opt['vcodec_prob']           # [1]

        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
                    f'x_sigma: [{", ".join(map(str, self.blur_x_sigma))}], '
                    f'y_sigma: [{", ".join(map(str, self.blur_y_sigma))}], ')
        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
        logger.info(f'CRF compression: [{", ".join(map(str, self.crf_range))}]')
        logger.info(f'Codec: [{", ".join(map(str, self.vcodec))}]')

        if self.need_align:
            self.dataroot_meta_info = opt['dataroot_meta_info']
            self.face_aligner = FaceAligner(
                upscale_factor=1,
                face_size=512,
                crop_ratio=(1, 1),
                det_model='retinaface_resnet50',
                save_ext='png',
                use_parse=True)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        key = self.keys[index]
        real_clip_path = '/'.join(key.split('/')[:-2])
        clip_length = int(key.split('/')[-2])
        frame_idx = int(key.split('/')[-1])
        clip_name = real_clip_path.split('/')[-1]

        if os.path.exists(os.path.join(self.gt_root, "train", clip_name)):
            paths = sorted(list(scandir(os.path.join(self.gt_root, "train", clip_name))))
        elif os.path.exists(os.path.join(self.gt_root, "test", clip_name)):
            paths = sorted(list(scandir(os.path.join(self.gt_root, "test", clip_name))))
        else:
            paths = sorted(list(scandir(os.path.join(self.gt_root, clip_name))))

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

        # re-draw the interval until num_frame frames fit in the clip
        # (this loops forever if no interval in interval_list fits)
        while (clip_length - self.num_frame * interval) < 0:
            interval = random.choice(self.interval_list)

        # ensure not exceeding the borders
        start_frame_idx = frame_idx - self.num_frame // 2 * interval
        end_frame_idx = frame_idx + (self.num_frame + 1) // 2 * interval

        while (start_frame_idx < 0) or (end_frame_idx > clip_length):
            frame_idx = random.randint(self.num_frame // 2 * interval,
                                       clip_length - self.num_frame // 2 * interval)
            start_frame_idx = frame_idx - self.num_frame // 2 * interval
            end_frame_idx = frame_idx + (self.num_frame + 1) // 2 * interval
        neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == self.num_frame, (
            f'Wrong length of neighbor list: {len(neighbor_list)}')

        # get the neighboring GT frames
        img_gts = []

        need_align = False
        if self.need_align:
            clip_info_path = os.path.join(self.dataroot_meta_info, f'{clip_name}.txt')
            if os.path.exists(clip_info_path):
                need_align = True
                clip_info = []
                with open(clip_info_path, 'r', encoding='utf-8') as fin:
                    for line in fin:
                        line = line.strip()
                        clip_info.append(line)

        for neighbor in neighbor_list:
            img_gt_path = os.path.join(self.gt_root, clip_name, paths[neighbor])
            if not os.path.exists(img_gt_path):
                img_gt_path = os.path.join(self.gt_root, "train", clip_name, paths[neighbor])
            if not os.path.exists(img_gt_path):
                img_gt_path = os.path.join(self.gt_root, "test", clip_name, paths[neighbor])

            # PIL loads RGB; reverse the channel axis to get BGR (OpenCV order) in [0, 1]
            img_gt = np.asarray(Image.open(img_gt_path))[:, :, ::-1] / 255.0
            img_gts.append(img_gt)

        # augmentation - flip, rotate
        img_gts = augment(img_gts, self.opt['use_flip'], self.opt['use_rot']) # False, False

        # ------------- generate grayscale frames --------------#
        # BGR [0, 1] -> uint8 -> single-channel gray -> replicated to 3 channels in [0, 1]
        img_lqs = [cv2.cvtColor((_ * 255).astype('uint8'), cv2.COLOR_BGR2GRAY) for _ in img_gts]
        img_lqs = [np.repeat(_[..., None], repeats=3, axis=2) / 255. for _ in img_lqs]

        # -------------- Align ---------------#
        if need_align:
            align_lqs, align_gts = [], []
            for i, (img_lq, img_gt) in enumerate(zip(img_lqs, img_gts)):
                # index landmarks by the frame actually loaded, so the sampling
                # interval and a random reverse are respected
                landmarks_str = clip_info[neighbor_list[i]].split(' ')
                landmarks = np.array([float(x) for x in landmarks_str]).reshape(5, 2)
                self.face_aligner.clean_all()

                # align and warp each face
                img_lq, img_gt = self.face_aligner.align_pair_face(img_lq, img_gt, landmarks)
                align_lqs.append(img_lq)
                align_gts.append(img_gt)
            img_lqs, img_gts = align_lqs, align_gts

        img_gts = img2tensor(img_gts)
        img_lqs = img2tensor(img_lqs)
        img_gts = torch.stack(img_gts, dim=0)
        img_lqs = torch.stack(img_lqs, dim=0)

        if self.normalize:
            normalize(img_lqs, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
            normalize(img_gts, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)

        return {'in': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)
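The temporal window in `__getitem__` is chosen by two rejection loops: one re-draws the interval until `num_frame` frames fit, and one re-draws the centre frame until the window stays inside the clip. Because every key is created with frame index 0, the centre is always re-drawn whenever `num_frame >= 2`. The sketch below is a hypothetical equivalent that samples directly from the valid choices, which gives the same distribution as the rejection loops but raises an error where the original interval loop would never terminate; `sample_neighbor_frames` and its `rng` argument are illustrative.

```python
import random


def sample_neighbor_frames(frame_idx, clip_length, num_frame, interval_list,
                           random_reverse=False, rng=random):
    """Hypothetical equivalent of the window sampling in ColorizationDataset.__getitem__."""
    # intervals for which num_frame frames fit in the clip
    fitting = [i for i in interval_list if num_frame * i <= clip_length]
    if not fitting:
        # the original `while` loop over interval_list would never terminate here
        raise ValueError('no interval in interval_list fits the clip')
    interval = rng.choice(fitting)
    before = num_frame // 2 * interval         # span before the centre frame
    after = (num_frame + 1) // 2 * interval    # span from the centre frame onwards
    if frame_idx - before < 0 or frame_idx + after > clip_length:
        # uniform over all valid centres, matching the original rejection loop
        frame_idx = rng.randint(before, clip_length - after)
    neighbors = list(range(frame_idx - before, frame_idx + after, interval))
    if random_reverse and rng.random() < 0.5:
        neighbors.reverse()
    return neighbors


rng = random.Random(0)
print(sample_neighbor_frames(0, clip_length=10, num_frame=5, interval_list=[2], rng=rng))
# [0, 2, 4, 6, 8]
```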

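The colorization inputs are produced by converting each BGR ground-truth frame to 8-bit gray with `cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)` and repeating the result on three channels. The NumPy sketch below approximates that step with the BT.601 luma weights OpenCV documents for this conversion (Y = 0.299 R + 0.587 G + 0.114 B); OpenCV uses fixed-point arithmetic internally, so individual pixels may differ by one gray level. `bgr_to_gray3` is a hypothetical helper.

```python
import numpy as np


def bgr_to_gray3(img_bgr):
    """Approximate the grayscale LQ frames of ColorizationDataset.

    img_bgr: float array (h, w, 3) in [0, 1], BGR channel order.
    Returns (h, w, 3) floats in [0, 1] with the luma repeated on every channel.
    """
    # truncate to uint8 first, as the dataset does before cv2.cvtColor
    img_u8 = (img_bgr * 255).astype(np.uint8)
    b, g, r = (img_u8[..., i].astype(np.float64) for i in range(3))
    # BT.601 luma weights (the coefficients OpenCV documents for BGR2GRAY)
    gray = np.round(0.299 * r + 0.587 * g + 0.114 * b)
    return np.repeat(gray[..., None], repeats=3, axis=2) / 255.0
```

Keeping three identical channels lets the grayscale input go through the same 3-channel network as the ground truth.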

================================================
FILE: basicsr/data/data_sampler.py
================================================
import math
import torch
from torch.utils.data.sampler import Sampler


class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler
    Support enlarging the dataset for iteration-based training, which saves
    the time of restarting the dataloader after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()

        dataset_size = len(self.dataset)
        indices = [v % dataset_size for v in indices]

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
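`EnlargedSampler` draws one permutation of `ratio * len(dataset)` (rounded up to a multiple of the process count) positions with a seed shared by all ranks, wraps positions back into the dataset with a modulo, and lets rank `r` take every `num_replicas`-th entry starting at `r`. The pure-Python sketch below reproduces that structure; it substitutes `random.Random(epoch).shuffle` for `torch.randperm` with an epoch-seeded `torch.Generator`, so the concrete permutation differs from PyTorch's. `enlarged_indices` is a hypothetical helper.

```python
import math
import random


def enlarged_indices(dataset_len, num_replicas, rank, ratio=1, epoch=0):
    """Pure-Python sketch of EnlargedSampler.__iter__.

    random.Random(epoch).shuffle stands in for torch.randperm with a
    torch.Generator seeded by the epoch, so the concrete permutation differs
    from PyTorch's; the enlarge / wrap / stride structure is the same.
    """
    num_samples = math.ceil(dataset_len * ratio / num_replicas)
    total_size = num_samples * num_replicas
    perm = list(range(total_size))
    random.Random(epoch).shuffle(perm)              # identical on every rank
    indices = [v % dataset_len for v in perm]       # enlarged positions wrap around
    return indices[rank:total_size:num_replicas]    # every num_replicas-th item


# 10 samples, ratio 2, 3 processes: ceil(20 / 3) = 7 indices per rank
parts = [enlarged_indices(10, num_replicas=3, rank=r, ratio=2, epoch=5) for r in range(3)]
print([len(p) for p in parts])  # [7, 7, 7]
```

Calling `set_epoch` before each epoch changes the seed, so every epoch gets a fresh shuffle that is still consistent across ranks.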


================================================
FILE: basicsr/data/data_util.py
================================================
import cv2
import math
import numpy as np
import torch
from os import path as osp
from PIL import Image, ImageDraw
from torch.nn import functional as F

from basicsr.data.transforms import mod_crop
from basicsr.utils import img2tensor, scandir


def read_img_seq(path, require_mod_crop=False, scale=1):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(list(scandir(path, full_path=True)))
    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    imgs = torch.stack(imgs, dim=0)
    return imgs
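The per-frame conversion in `read_img_seq` (BGR `uint8` from `cv2.imread`, scaled to [0, 1], flipped to RGB, transposed to CHW, stacked along time) can be sketched with NumPy alone. The helper name `bgr_uint8_frames_to_tchw` is hypothetical, and it works on in-memory arrays instead of reading files, so neither OpenCV nor torch is needed.

```python
import numpy as np


def bgr_uint8_frames_to_tchw(frames):
    """Hypothetical helper mimicking read_img_seq on in-memory frames.

    frames: list of (h, w, 3) uint8 arrays in BGR order (as cv2.imread returns).
    Returns a float32 array of shape (t, 3, h, w), RGB, range [0, 1].
    """
    out = []
    for img in frames:
        img = img.astype(np.float32) / 255.0   # scale to [0, 1]
        img = img[..., ::-1]                   # BGR -> RGB
        out.append(img.transpose(2, 0, 1))     # HWC -> CHW
    return np.stack(out, axis=0)               # stack along a new time axis
```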


def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Generate an index list for reading `num_frames` frames from a sequence
    of images.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Total number of frames in the sequence (counted from 1).
        num_frames (int): Number of frames to read; must be odd.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: crt_idx = 0, num_frames = 5
            The generated frame indices under different padding mode:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    max_frame_num = max_frame_num - 1  # start from 0
    num_pad = num_frames // 2

    indices = []
    for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
        if i < 0:
            if padding == 'replicate':
                pad_idx = 0
            elif padding == 'reflection':
                pad_idx = -i
            elif padding == 'reflection_circle':
                pad_idx = crt_idx + num_pad - i
            else:
                pad_idx = num_frames + i
        elif i > max_frame_num:
            if padding == 'replicate':
                pad_idx = max_frame_num
            elif padding == 'reflection':
                pad_idx = max_frame_num * 2 - i
            elif padding == 'reflection_circle':
                pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
            else:
                pad_idx = i - num_frames
        else:
            pad_idx = i
        indices.append(pad_idx)
    return indices
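For the two most common modes the padding rule reduces to a clamp ('replicate') or a single mirror about the first or last frame ('reflection'). The sketch below is a hypothetical simplified version (`frame_window` is not part of basicsr) that supports only those two modes and uses the window to gather a clip of frame names; the single mirror is only valid while `num_frames // 2` does not exceed the last index.

```python
def frame_window(crt_idx, num_frames_total, num_frames, padding='reflection'):
    """Hypothetical simplified generate_frame_indices: 'replicate' and 'reflection' only."""
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    last = num_frames_total - 1          # largest valid 0-based index
    half = num_frames // 2
    out = []
    for i in range(crt_idx - half, crt_idx + half + 1):
        if padding == 'replicate':
            j = min(max(i, 0), last)     # clamp to the valid range
        elif i < 0:
            j = -i                       # mirror about frame 0
        elif i > last:
            j = 2 * last - i             # mirror about the last frame
        else:
            j = i
        out.append(j)
    return out


names = [f'{k:08d}.png' for k in range(10)]
clip = [names[j] for j in frame_window(0, len(names), 5)]
```

At the start of the sequence this reproduces the docstring examples (`[2, 1, 0, 1, 2]` and `[0, 0, 0, 1, 2]`).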


def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of an lmdb folder, taking `lq.lmdb` as an example:

    lq.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a txt file that records the meta information of our
    datasets. It is created automatically when the datasets are prepared with
    our dataset tools. Each line records 1) the image name (with extension),
    2) the image shape and 3) the compression level, separated by white space.
    Example: `baboon.png (120,125,3) 1`

    We use the image name without extension as the lmdb key.
    Note that we use the same key for the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[dict]: One dict per image, mapping f'{input_key}_path' and
            f'{gt_key}_path' to the shared lmdb key.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
        raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
                         f'format. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')
    # ensure that the two meta_info files are the same
    with open(osp.join(input_folder, 'meta_info.txt')) as fin:
        input_lmdb_keys = [line.split('.')[0] for line in fin]
    with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
        gt_lmdb_keys = [line.split('.')[0] for line in fin]
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
    else:
        paths = []
        for lmdb_key in sorted(input_lmdb_keys):
            paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
        return paths
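The lmdb keys above are derived with `line.split('.')[0]`, which cuts an image name at its first dot. A more defensive parse, sketched below as an alternative rather than what basicsr does, takes the first whitespace-separated field and strips only the final extension with `os.path.splitext`. The helper name `lmdb_keys_from_meta_lines` is hypothetical.

```python
import os.path as osp


def lmdb_keys_from_meta_lines(lines):
    """Hypothetical helper: derive lmdb keys from meta_info.txt lines.

    Each line looks like 'baboon.png (120,125,3) 1'; the key is the image
    name without its final extension.
    """
    keys = []
    for line in lines:
        line = line.strip()
        if not line:
            continue                      # skip blank lines
        name = line.split(' ')[0]         # first field: image name with extension
        keys.append(osp.splitext(name)[0])
    return keys
```

The difference only matters when image names contain extra dots, e.g. `clip_001.v2.png`, which `split('.')[0]` would shorten to `clip_001`.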


def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from an meta information file.

    Each line in the meta information file contains the image names and
    image shape (usually for gt), separated by a white space.

    Example of an meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per image with keys f'{input_key}_path' and
            f'{gt_key}_path'.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    with open(meta_info_file, 'r') as fin:
        gt_names = [line.split(' ')[0] for line in fin]

    paths = []
    for gt_name in gt_names:
        basename, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        gt_path = osp.join(gt_folder, gt_name)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths
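To make the role of `filename_tmpl` concrete: the gt basename (without extension) is formatted into the template and the gt extension is re-appended, so a template such as `'{}x4'` maps `0001_s001.png` to `0001_s001x4.png`. The sketch below restates that pairing on in-memory meta lines; the function name `pair_from_meta_lines` and the `'{}x4'` template are illustrative assumptions.

```python
import os.path as osp


def pair_from_meta_lines(lines, input_folder, gt_folder, filename_tmpl='{}', keys=('lq', 'gt')):
    """Hypothetical restatement of paired_paths_from_meta_info_file on a list of lines."""
    input_key, gt_key = keys
    paths = []
    for line in lines:
        gt_name = line.split(' ')[0]                      # first field: gt image name
        basename, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(basename)}{ext}'  # template excludes the extension
        paths.append({f'{input_key}_path': osp.join(input_folder, input_name),
                      f'{gt_key}_path': osp.join(gt_folder, gt_name)})
    return paths


pairs = pair_from_meta_lines(['0001_s001.png (480,480,3)\n', '0001_s002.png (480,480,3)\n'],
                             'lq_dir', 'gt_dir', filename_tmpl='{}x4')
```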


def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per image with keys f'{input_key}_path' and
            f'{gt_key}_path'.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    paths = []
    for gt_path in gt_paths:
        basename, ext = osp.splitext(osp.basename(gt_path))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
        gt_path = osp.join(gt_folder, gt_path)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths
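The folder variant can be exercised end to end with a temporary directory. The stdlib sketch below (`pair_folder_paths` is a hypothetical name) follows the same steps, with two deliberate differences: it sorts the gt names, since directory listing order is not guaranteed, and it checks membership against a set instead of a list.

```python
import os
import os.path as osp
import tempfile


def pair_folder_paths(input_folder, gt_folder, filename_tmpl='{}'):
    """Hypothetical stdlib-only version of paired_paths_from_folder."""
    input_names = set(os.listdir(input_folder))
    gt_names = sorted(os.listdir(gt_folder))
    assert len(input_names) == len(gt_names), 'folders hold different numbers of images'
    pairs = []
    for gt_name in gt_names:
        basename, ext = osp.splitext(gt_name)
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        assert input_name in input_names, f'{input_name} missing from {input_folder}'
        pairs.append({'lq_path': osp.join(input_folder, input_name),
                      'gt_path': osp.join(gt_folder, gt_name)})
    return pairs


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as root:
        lq, gt = osp.join(root, 'lq'), osp.join(root, 'gt')
        os.makedirs(lq)
        os.makedirs(gt)
        for name in ('a', 'b'):
            open(osp.join(gt, name + '.png'), 'w').close()
            open(osp.join(lq, name + '_x4.png'), 'w').close()
        print(len(pair_folder_paths(lq, gt, '{}_x4')))  # -> 2
```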


def paths_from_folder(folder):
    """Generate paths from folder.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """

    paths = list(scandir(folder))
    paths = [osp.join(folder, path) for path in paths]
    return paths


def paths_from_lmdb(folder):
    """Generate paths from lmdb.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """
    if not folder.endswith('.lmdb'):
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        paths = [line.split('.')[0] for line in fin]
    return paths


def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.ndarray: The Gaussian kernel.
    """
    # scipy.ndimage.filters is a deprecated namespace; import from scipy.ndimage directly
    from scipy.ndimage import gaussian_filter
    kernel = np.zeros((kernel_size, kernel_size))
    # set element at the middle to one, a dirac delta
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # gaussian-smooth the dirac, resulting in a gaussian filter
    return gaussian_filter(kernel, sigma)
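Smoothing a centered delta with `scipy.ndimage.gaussian_filter` yields the outer product of two normalized, sampled 1D Gaussians. SciPy truncates each profile at `int(truncate * sigma + 0.5)` pixels with `truncate=4.0` by default, so for `kernel_size=13, sigma=1.6` (radius 6) we expect the NumPy-only sketch below to agree with `generate_gaussian_kernel` up to floating-point error. For smaller sigmas (e.g. 0.8 at scale 2, radius 3) SciPy zeroes the outer taps while this sketch does not. The name `sampled_gaussian_kernel` is hypothetical.

```python
import numpy as np


def sampled_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Hypothetical numpy-only approximation of generate_gaussian_kernel (no truncation)."""
    radius = kernel_size // 2
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    g = np.exp(-0.5 * (x / sigma) ** 2)
    g /= g.sum()                  # normalize the 1D profile
    return np.outer(g, g)         # separable: 2D kernel = outer product of 1D profiles
```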


def duf_downsample(x, kernel_size=13, scale=4):
    """Downsamping with Gaussian kernel used in the DUF official code.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    squeeze_flag = False
    if x.ndim == 4:
        squeeze_flag = True
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()
    x = x.view(-1, 1, h, w)
    pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
    x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')

    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
    x = F.conv2d(x, gaussian_filter, stride=scale)
    x = x[:, :, 2:-2, 2:-2]
    x = x.view(b, t, c, x.size(2), x.size(3))
    if squeeze_flag:
        x = x.squeeze(0)
    return x
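The output size of `duf_downsample` follows from three steps: reflect padding of `kernel_size // 2 + 2 * scale` per side, a valid convolution with stride `scale`, and the `[2:-2]` crop. The small helper below (a hypothetical name, pure arithmetic) shows that with `kernel_size=13` an input size divisible by `scale` comes out as exactly `size / scale`.

```python
def duf_output_size(size, kernel_size=13, scale=4):
    """Spatial output size of duf_downsample for one dimension (pure arithmetic)."""
    pad = kernel_size // 2 + scale * 2                   # reflect padding on each side
    conv = (size + 2 * pad - kernel_size) // scale + 1   # valid conv with stride=scale
    return conv - 4                                      # [2:-2] removes 2 px per side


print(duf_output_size(64, scale=4))  # -> 16
```

For `kernel_size=13` the padded, convolved size is `(size + 4 * scale - 1) // scale + 1`, which equals `size / scale + 4` whenever `scale` divides `size`; the crop then removes the extra 4.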


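def brush_stroke_mask(img, color=(255,255,255)):
    """Draw random free-form brush strokes onto a PIL image (in place).

    One to three strokes are drawn. Each stroke is a random zig-zag polyline
    with 8 to 27 segments, a random width in [min_width, max_width) pixels
    and rounded joints.

    Args:
        img (PIL.Image.Image): Image to draw on. It is modified in place.
        color (tuple): Fill color of the strokes. Default: (255, 255, 255).

    Returns:
        PIL.Image.Image: The same image object with the strokes drawn on it.
    """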
    min_num_vertex = 8
    max_num_vertex = 28
    mean_angle = 2*math.pi / 5
    angle_range = 2*math.pi / 12
    # training large mask ratio (training setting)
    min_width = 30
    max_width = 70
    # very large mask ratio (test setting and refine after 200k)
    # min_width = 80
    # max_width = 120
    def generate_mask(H, W, img=None):
        average_radius = math.sqrt(H*H+W*W) / 8
        mask = Image.new('RGB', (W, H), 0)
        if img is not None:
            mask = img # Image.fromarray(img)

        for _ in range(np.random.randint(1, 4)):
            num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
            angle_min = mean_angle - np.random.uniform(0, angle_range)
            angle_max = mean_angle + np.random.uniform(0, angle_range)
            angles = []
            vertex = []
            for i in range(num_vertex):
                if i % 2 == 0:
                    angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
                else:
                    angles.append(np.random.uniform(angle_min, angle_max))

            w, h = mask.size  # PIL's Image.size is (width, height)
            vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
            for i in range(num_vertex):
                r = np.clip(
                    np.random.normal(loc=average_radius, scale=average_radius//2),
                    0, 2*average_radius)
                new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
                new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
                vertex.append((int(new_x), int(new_y)))

            draw = ImageDraw.Draw(mask)
            width = int(np.random.uniform(min_width, max_width))
            draw.line(vertex, fill=color, width=width)
            for v in vertex:
                draw.ellipse((v[0] - width//2,
                              v[1] - width//2,
                              v[0] + width//2,
                              v[1] + width//2),
                             fill=color)

        return mask

    width, height = img.size
    mask = generate_mask(height, width, img)
    return mask
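The vertex sampling inside `brush_stroke_mask` is a clipped random walk: the turning angles alternate between roughly `a` and `2*pi - a`, so the stroke zig-zags, and each step length is drawn around one eighth of the image diagonal. The stdlib sketch below ports only that sampling (drawing is left out, so PIL is not needed); `random_stroke_vertices` is a hypothetical name, and `random.Random` stands in for `np.random`.

```python
import math
import random


def random_stroke_vertices(height, width, rng, min_num_vertex=8, max_num_vertex=28,
                           mean_angle=2 * math.pi / 5, angle_range=2 * math.pi / 12):
    """Hypothetical stdlib port of the vertex sampling in brush_stroke_mask."""
    average_radius = math.sqrt(height ** 2 + width ** 2) / 8
    num_vertex = rng.randrange(min_num_vertex, max_num_vertex)
    angle_min = mean_angle - rng.uniform(0, angle_range)
    angle_max = mean_angle + rng.uniform(0, angle_range)
    # Alternate "back" (2*pi - a) and "forward" (a) turns so the stroke zig-zags.
    angles = [2 * math.pi - rng.uniform(angle_min, angle_max) if i % 2 == 0
              else rng.uniform(angle_min, angle_max) for i in range(num_vertex)]
    vertices = [(rng.randrange(width), rng.randrange(height))]   # (x, y) start point
    for angle in angles:
        # Step length ~ N(average_radius, average_radius // 2), clipped to [0, 2 * average_radius].
        r = min(max(rng.gauss(average_radius, average_radius // 2), 0), 2 * average_radius)
        x = min(max(vertices[-1][0] + r * math.cos(angle), 0), width)
        y = min(max(vertices[-1][1] + r * math.sin(angle), 0), height)
        vertices.append((int(x), int(y)))
    return vertices


verts = random_stroke_vertices(256, 512, random.Random(0))
```

As in the original, coordinates are clipped to `[0, width]` and `[0, height]` inclusive; PIL tolerates points on the border.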



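def brush_stroke_mask_video(imgs, color=(255,255,255)):
    """Draw the same random brush strokes onto every frame of a clip (in place).

    The stroke geometry is sampled once per stroke and reused for all frames,
    so the mask is temporally consistent. All frames are assumed to have the
    size of imgs[0].

    Args:
        imgs (list[PIL.Image.Image]): Frames to draw on, modified in place.
        color (tuple): Fill color of the strokes. Default: (255, 255, 255).

    Returns:
        list[PIL.Image.Image]: The same frame list with the strokes drawn on it.
    """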
    min_num_vertex = 8
    max_num_vertex = 28
    mean_angle = 2 * math.pi / 5
    angle_range = 2 * math.pi / 12
    # training large mask ratio (training setting)
    min_width = 30
    max_width = 70
    # very large mask ratio (test setting and refine after 200k)
    # min_width = 80
    # max_width = 120
    def generate_mask(H, W, imgs=None):
        average_radius = math.sqrt(H*H+W*W) / 8
        # mask = Image.new('RGB', (W, H), 0)
        # if img is not None:
        #     mask = img # Image.fromarray(img)

        for _ in range(np.random.randint(1, 4)):
            num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
            angle_min = mean_angle - np.random.uniform(0, angle_range)
            angle_max = mean_angle + np.random.uniform(0, angle_range)
            angles = []
            vertex = []
            for i in range(num_vertex):
                if i % 2 == 0:
                    angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
                else:
                    angles.append(np.random.uniform(angle_min, angle_max))

            w, h = imgs[0].size  # PIL's Image.size is (width, height)
            vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
            for i in range(num_vertex):
                r = np.clip(
                    np.random.normal(loc=average_radius, scale=average_radius//2),
                    0, 2*average_radius)
                new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
                new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
                vertex.append((int(new_x), int(new_y)))

            width_ = int(np.random.uniform(min_width, max_width))
            for img in imgs:
                draw = ImageDraw.Draw(img)
                draw.line(vertex, fill=color, width=width_)
                for v in vertex:
                    draw.ellipse((v[0] - width_//2,
                                 v[1] - width_//2,
                                 v[0] + width_//2,
                                 v[1] + width_//2),
                                 fill=color)

        return imgs

    width, height = imgs[0].size
    mask = generate_mask(height, width, imgs)
    return mask
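The key property of the video variant is that one stroke geometry is reused across all frames. The sketch below illustrates that with NumPy instead of PIL: it rasterizes a polyline as "all pixels within `stroke_width / 2` of some segment" (which also produces round caps and joints) and applies the single mask to every frame. Both function names are hypothetical, and this distance-based rasterizer is a swapped-in technique, not the PIL drawing used above.

```python
import numpy as np


def stroke_mask(height, width, vertices, stroke_width):
    """Hypothetical rasterizer: True where a pixel is within stroke_width / 2 of the polyline."""
    ys, xs = np.mgrid[0:height, 0:width].astype(np.float64)
    mask = np.zeros((height, width), dtype=bool)
    half = stroke_width / 2.0
    for (x0, y0), (x1, y1) in zip(vertices[:-1], vertices[1:]):
        dx, dy = x1 - x0, y1 - y0
        seg_len2 = dx * dx + dy * dy
        if seg_len2 == 0:
            t = np.zeros_like(xs)              # degenerate segment: distance to the point
        else:
            # Project each pixel onto the segment and clamp to its endpoints.
            t = np.clip(((xs - x0) * dx + (ys - y0) * dy) / seg_len2, 0.0, 1.0)
        dist2 = (xs - (x0 + t * dx)) ** 2 + (ys - (y0 + t * dy)) ** 2
        mask |= dist2 <= half * half
    return mask


def mask_clip(frames, vertices, stroke_width, fill=1.0):
    """Apply one shared stroke mask to every (h, w, c) frame of a clip."""
    h, w = frames[0].shape[:2]
    mask = stroke_mask(h, w, vertices, stroke_width)   # rasterized once for the whole clip
    out = []
    for frame in frames:
        frame = frame.copy()
        frame[mask] = fill
        out.append(frame)
    return out, mask
```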




def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70, times = 10):
    """Generate a random free form mask with configuration.
    Args:
        config: Config should have configuration including IMG_SHAPES,
            VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
    Returns:
        tuple: (top, left, height, width)
    Link:
        https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py
    """
    height = shape[0]
    width = shape[1]
    mask = np.zeros((height, width), np.float32)
    times = np.random.randint(times-5, times)
    for i in range(times):
        start_x = np.random.randint(width)
        start_y = np.random.randint(height)
        for j in range(1 + np.random.randint(5)):
            angle = 0.01 + np.random.randint(max_angle)
            if i % 2 == 0:
                angle = 2 * 3.1415926 - angle
            length = 10 + np.random.randint(max_len-20, max_len)
            brush_w = 5 + np.random.randint(max_width-30, max_width)
            end_x = int(start_x + length * np.sin(angle))
            end_y = int(start_y + length * np.cos(angle))
            # cv2 points are (x, y) = (column, row); cast to plain ints for OpenCV
            cv2.line(mask, (int(start_x), int(start_y)), (end_x, end_y), 1.0, int(brush_w))
            start_x, start_y = end_x, end_y
    return mask.astype(np.float32)

================================================
FILE: basicsr/data/degradations.py
================================================
import cv2
import math
import numpy as np
import random
import torch
from scipy import special
from scipy.stats import multivariate_normal
from torchvision.transforms.functional import rgb_to_grayscale

# -------------------------------------------------------------------- #
# --------------------------- blur kernels --------------------------- #
# -------------------------------------------------------------------- #


# --------------------------- util functions --------------------------- #
def sigma_matrix2(sig_x, sig_y, theta):
    """Calculate the rotated sigma matrix (two dimensional matrix).

    Args:
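        sig_x (float): Standard deviation along the first principal axis.
        sig_y (float): Standard deviation along the second principal axis.
        theta (float): Rotation angle of the principal axes, in radians.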

    Returns:
        ndarray: Rotated sigma matrix.
    """
    d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
    u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
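`sigma_matrix2` builds the covariance as Sigma = U diag(sig_x^2, sig_y^2) U^T with a rotation matrix U, so its eigenvalues are the squared standard deviations and the eigenvector of sig_x^2 points along (cos theta, sin theta). The NumPy check below restates the same construction under a hypothetical name and verifies those properties.

```python
import numpy as np


def rotated_covariance(sig_x, sig_y, theta):
    # Same construction as sigma_matrix2: Sigma = U diag(sig_x^2, sig_y^2) U^T.
    u = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    return u @ np.diag([sig_x ** 2, sig_y ** 2]) @ u.T


cov = rotated_covariance(2.0, 0.5, np.pi / 6)
eigvals, eigvecs = np.linalg.eigh(cov)   # eigh returns eigenvalues in ascending order
```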


def mesh_grid(kernel_size):
    """Generate the mesh grid, centering at zero.

    Args:
        kernel_size (int): Kernel size; should be odd so that the grid is centered at zero.

    Returns:
        xy (ndarray): with the shape (kernel_size, kernel_size, 2)
        xx (ndarray): with the shape (kernel_size, kernel_size)
        yy (ndarray): with the shape (kernel_size, kernel_size)
    """
    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)
    # xy[i, j] = (xx[i, j], yy[i, j])
    xy = np.stack((xx, yy), axis=-1)
    return xy, xx, yy
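The axis in `mesh_grid` relies on Python parsing `-kernel_size // 2` as `(-kernel_size) // 2`, which floors toward negative infinity. For odd sizes this gives a grid symmetric about zero; for even sizes it does not, which is one reason the random kernel samplers assert an odd `kernel_size`. The helper name `axis_coords` is hypothetical.

```python
import numpy as np


def axis_coords(k):
    # Same expression as in mesh_grid; -k // 2 is (-k) // 2 in Python.
    return np.arange(-k // 2 + 1., k // 2 + 1.).tolist()


print(axis_coords(3))  # -> [-1.0, 0.0, 1.0]
print(axis_coords(4))  # -> [-1.0, 0.0, 1.0, 2.0]  (not centered)
```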


def pdf2(sigma_matrix, grid):
    """Calculate PDF of the bivariate Gaussian distribution.

    Args:
        sigma_matrix (ndarray): with the shape (2, 2)
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        kernel (ndarray): un-normalized kernel.
    """
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
    return kernel


def cdf2(d_matrix, grid):
    """Calculate the CDF of the standard bivariate Gaussian distribution.
        Used in skewed Gaussian distribution.

    Args:
        d_matrix (ndarrasy): skew matrix.
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        cdf (ndarray): skewed cdf.
    """
    rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
    grid = np.dot(grid, d_matrix)
    cdf = rv.cdf(grid)
    return cdf


def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
    """Generate a bivariate isotropic or anisotropic Gaussian kernel.

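    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int): Kernel size.
        sig_x (float): Standard deviation along the first principal axis.
        sig_y (float): Standard deviation along the second principal axis.
        theta (float): Rotation angle in radians.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool): Whether to build an isotropic kernel. Default: True.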

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    kernel = pdf2(sigma_matrix, grid)
    kernel = kernel / np.sum(kernel)
    return kernel
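The anisotropic path of `bivariate_Gaussian` evaluates exp(-0.5 v^T Sigma^{-1} v) on the mesh grid and normalizes the result to sum to 1. The self-contained restatement below (hypothetical name `gaussian_kernel`, using `@` where the original uses `np.dot`) makes two properties easy to check: a larger `sig_x` spreads the kernel along the columns, and rotating by pi/2 transposes it.

```python
import numpy as np


def gaussian_kernel(kernel_size, sig_x, sig_y, theta):
    """Stand-alone sketch of bivariate_Gaussian(..., isotropic=False)."""
    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)
    grid = np.stack((xx, yy), axis=-1)                       # (K, K, 2) of (x, y)
    u = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    inv = np.linalg.inv(u @ np.diag([sig_x ** 2, sig_y ** 2]) @ u.T)
    q = np.sum((grid @ inv) * grid, axis=-1)                 # v^T Sigma^{-1} v per grid point
    k = np.exp(-0.5 * q)
    return k / k.sum()


k0 = gaussian_kernel(21, 4.0, 1.0, 0.0)
k90 = gaussian_kernel(21, 4.0, 1.0, np.pi / 2)
```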


def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a bivariate generalized Gaussian kernel.

    ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``

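    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int): Kernel size.
        sig_x (float): Standard deviation along the first principal axis.
        sig_y (float): Standard deviation along the second principal axis.
        theta (float): Rotation angle in radians.
        beta (float): shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool): Whether to build an isotropic kernel. Default: True.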

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
    kernel = kernel / np.sum(kernel)
    return kernel
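The generalized kernel raises the quadratic form to the power `beta` before exponentiating: exp(-0.5 q^beta). The isotropic sketch below (hypothetical name `generalized_kernel`) checks that `beta = 1` reproduces the ordinary Gaussian, and that a larger `beta` flattens the center, because q^beta < q for q < 1.

```python
import numpy as np


def generalized_kernel(kernel_size, sigma, beta):
    """Isotropic sketch of bivariate_generalized_Gaussian: exp(-0.5 * q ** beta)."""
    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)
    q = (xx ** 2 + yy ** 2) / sigma ** 2       # v^T Sigma^{-1} v with Sigma = sigma^2 I
    k = np.exp(-0.5 * np.power(q, beta))
    return k / k.sum()


k1 = generalized_kernel(21, 3.0, 1.0)
k4 = generalized_kernel(21, 3.0, 4.0)
```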


def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a plateau-like anisotropic kernel.

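    The un-normalized kernel is 1 / (1 + (v^T Sigma^{-1} v)^beta), where v is
    the grid coordinate and Sigma is the covariance matrix.

    Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution

    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int): Kernel size.
        sig_x (float): Standard deviation along the first principal axis.
        sig_y (float): Standard deviation along the second principal axis.
        theta (float): Rotation angle in radians.
        beta (float): Shape parameter; a larger beta gives a flatter top and
            a sharper fall-off.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool): Whether to build an isotropic kernel. Default: True.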

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
    kernel = kernel / np.sum(kernel)
    return kernel
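Before normalization, the plateau profile 1 / (1 + q^beta) equals 1 at the center and 0.5 at q = 1 for every beta. A larger beta pushes values for q < 1 toward 1 (flat top) and values for q > 1 toward 0 (steep edge). The snippet below evaluates the profile on a few quadratic-form values; `plateau_profile` is a hypothetical name.

```python
import numpy as np


def plateau_profile(q, beta):
    """Un-normalized plateau profile 1 / (1 + q**beta), with q = v^T Sigma^{-1} v."""
    return np.reciprocal(np.power(q, beta) + 1)


q = np.array([0.0, 0.5, 1.0, 2.0])
p1 = plateau_profile(q, 1.0)   # [1, 2/3, 1/2, 1/3]
p4 = plateau_profile(q, 4.0)
```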


def random_bivariate_Gaussian(kernel_size,
                              sigma_x_range,
                              sigma_y_range,
                              rotation_range,
                              noise_range=None,
                              isotropic=True):
    """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.

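    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int): Kernel size; must be odd.
        sigma_x_range (tuple): Range of sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): Range of sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): Range of the rotation angle in radians,
            e.g. [-math.pi, math.pi].
        noise_range (tuple, optional): Range of multiplicative kernel noise,
            e.g. [0.75, 1.25]. Default: None.
        isotropic (bool): Whether to sample an isotropic kernel. Default: True.

    Returns:
        ndarray: Sampled kernel, normalized to sum to 1.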
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel
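The optional multiplicative noise step shared by the `random_*` samplers multiplies every tap by an independent uniform factor and renormalizes, so the kernel still sums to 1 and any two taps of a flat kernel differ by at most a factor of `high / low`. The sketch below isolates that step; `perturb_kernel` is a hypothetical name, and it takes an explicit `np.random.Generator` for reproducibility where the original uses the global `np.random` state.

```python
import numpy as np


def perturb_kernel(kernel, noise_range, rng):
    """Multiply each tap by U(noise_range) noise, then renormalize to sum to 1."""
    low, high = noise_range
    assert low < high, 'Wrong noise range.'
    noise = rng.uniform(low, high, size=kernel.shape)
    kernel = kernel * noise
    return kernel / np.sum(kernel)


out = perturb_kernel(np.full((5, 5), 1 / 25), (0.75, 1.25), np.random.default_rng(0))
```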


def random_bivariate_generalized_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          beta_range,
                                          noise_range=None,
                                          isotropic=True):
    """Randomly generate bivariate generalized Gaussian kernels.

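    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int): Kernel size; must be odd.
        sigma_x_range (tuple): Range of sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): Range of sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): Range of the rotation angle in radians,
            e.g. [-math.pi, math.pi].
        beta_range (tuple): Range of the shape parameter beta, e.g. [0.5, 8].
        noise_range (tuple, optional): Range of multiplicative kernel noise,
            e.g. [0.75, 1.25]. Default: None.
        isotropic (bool): Whether to sample an isotropic kernel. Default: True.

    Returns:
        ndarray: Sampled kernel, normalized to sum to 1.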
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    # assume beta_range[0] < 1 < beta_range[1]
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel
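The beta sampling picks a side of 1 with probability 1/2 and then samples uniformly on that side. With `beta_range = (0.5, 8)`, a plain uniform draw would fall below 1 only about 6.7% of the time (0.5 / 7.5). The split keeps peaky, heavy-tailed kernels (beta < 1) and flat-topped kernels (beta > 1) equally likely. The stdlib sketch below (`sample_beta` is a hypothetical name) mirrors that rule.

```python
import random


def sample_beta(beta_range, rng):
    """Pick the side of 1 with probability 1/2, then sample uniformly on that side.

    Assumes beta_range[0] < 1 < beta_range[1], as the sampler above does.
    """
    if rng.random() < 0.5:
        return rng.uniform(beta_range[0], 1)
    return rng.uniform(1, beta_range[1])


rng = random.Random(0)
samples = [sample_beta((0.5, 8), rng) for _ in range(10000)]
```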


def random_bivariate_plateau(kernel_size,
                             sigma_x_range,
                             sigma_y_range,
                             rotation_range,
                             beta_range,
                             noise_range=None,
                             isotropic=True):
    """Randomly generate bivariate plateau kernels.

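    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int): Kernel size; must be odd.
        sigma_x_range (tuple): Range of sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): Range of sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): Range of the rotation angle in radians,
            e.g. [-math.pi/2, math.pi/2].
        beta_range (tuple): Range of the shape parameter beta, e.g. [1, 4].
        noise_range (tuple, optional): Range of multiplicative kernel noise,
            e.g. [0.75, 1.25]. Default: None.
        isotropic (bool): Whether to sample an isotropic kernel. Default: True.

    Returns:
        ndarray: Sampled kernel, normalized to sum to 1.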
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

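    # TODO: this may not be proper. It assumes beta_range[0] < 1 < beta_range[1];
    # with e.g. [1, 4], the lower branch always returns exactly 1.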
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)

    return kernel


def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size=21,
                         sigma_x_range=(0.6, 5),
                         sigma_y_range=(0.6, 5),
                         rotation_range=(-math.pi, math.pi),
                         betag_range=(0.5, 8),
                         betap_range=(0.5, 8),
                         noise_range=None):
    """Randomly generate mixed kernels.

    Args:
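        kernel_list (tuple): Names of the candidate kernel types. Supported
            values are 'iso', 'aniso', 'generalized_iso', 'generalized_aniso',
            'plateau_iso' and 'plateau_aniso'.
        kernel_prob (tuple): Sampling probability (weight) for each entry of
            kernel_list.
        kernel_size (int): Kernel size; must be odd. Default: 21.
        sigma_x_range (tuple): Range of sigma_x. Default: (0.6, 5).
        sigma_y_range (tuple): Range of sigma_y. Default: (0.6, 5).
        rotation_range (tuple): Range of the rotation angle in radians.
            Default: (-math.pi, math.pi).
        betag_range (tuple): beta range for generalized Gaussian kernels.
            Default: (0.5, 8).
        betap_range (tuple): beta range for plateau kernels. Default: (0.5, 8).
        noise_range (tuple, optional): Range of multiplicative kernel noise,
            e.g. [0.75, 1.25]. Not applied to plateau kernels. Default: None.

    Returns:
        ndarray: The sampled kernel, normalized to sum to 1.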
    """
    kernel_type = random.choices(kernel_list, kernel_prob)[0]
    if kernel_type == 'iso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
    elif kernel_type == 'aniso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
    elif kernel_type == 'generalized_iso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=True)
    elif kernel_type == 'generalized_aniso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=False)
    elif kernel_type == 'plateau_iso':
        kernel = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
    elif kernel_type == 'plateau_aniso':
        kernel = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
    else:
        raise ValueError(f'Unsupported kernel type: {kernel_type}.')
    return kernel
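The kernel type is drawn with `random.choices(kernel_list, kernel_prob)[0]`, a weighted draw whose weights need not sum to 1. The sketch below (hypothetical `pick_kernel_type`) also rejects names that the dispatch above does not handle, such as 'skew', before drawing, and tallies 10,000 draws to show the frequencies track the weights.

```python
import random
from collections import Counter

SUPPORTED = ('iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso')


def pick_kernel_type(kernel_list, kernel_prob, rng=random):
    """Hypothetical helper: validate the names, then draw one type by weight."""
    unknown = set(kernel_list) - set(SUPPORTED)
    if unknown:
        raise ValueError(f'Unsupported kernel types: {sorted(unknown)}')
    # random.choices samples with replacement; [0] keeps the single draw.
    return rng.choices(kernel_list, kernel_prob)[0]


rng = random.Random(0)
counts = Counter(pick_kernel_type(['iso', 'aniso', 'plateau_iso'], [0.5, 0.3, 0.2], rng)
                 for _ in range(10000))
```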


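# Silence 0/0 warnings such as the kernel center in circular_lowpass_kernel.
# Note that np.seterr changes NumPy's global floating-point error handling.
np.seterr(divide='ignore', invalid='ignore')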


def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
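    """2D circularly symmetric sinc (jinc) low-pass filter.

    Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter

    Args:
        cutoff (float): Cutoff frequency in radians (pi is the maximum).
        kernel_size (int): Horizontal and vertical size; must be odd.
        pad_to (int): Pad the kernel to this size; must be odd or zero.
            Padding is applied only when pad_to > kernel_size.

    Returns:
        kernel (ndarray): Normalized kernel of shape (kernel_size, kernel_size),
            or (pad_to, pad_to) when padding is applied.
    """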
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    kernel = np.fromfunction(
        lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
            (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
                (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
    kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
    kernel = kernel / np.sum(kernel)
    if pad_to > kernel_size:
        pad_size = (pad_to - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
    return kernel
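For reference, the lambda passed to `np.fromfunction` evaluates the expression below at every pixel (m, n), where ω_c is `cutoff` and K is `kernel_size`; the symbols are written out here only for readability. The center entry is set to the limit as r approaches 0, using J_1(x) ≈ x/2 for small x, and the kernel is then divided by its sum.

```latex
c = \frac{K - 1}{2}, \qquad r_{mn} = \sqrt{(m - c)^2 + (n - c)^2}

h_{mn} =
\begin{cases}
  \dfrac{\omega_c \, J_1(\omega_c r_{mn})}{2\pi \, r_{mn}}, & r_{mn} > 0,\\[2ex]
  \dfrac{\omega_c^2}{4\pi}, & r_{mn} = 0,
\end{cases}
\qquad
\hat{h}_{mn} = \frac{h_{mn}}{\sum_{m', n'} h_{m'n'}}
```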


# ------------------------------------------------------------- #
# --------------------------- noise --------------------------- #
# ------------------------------------------------------------- #

# ----------------------- Gaussian Noise ----------------------- #


def generate_gaussian_noise(img, sigma=10, gray_noise=False):
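    """Generate Gaussian noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): Noise standard deviation, measured on the [0, 255] scale. Default: 10.
        gray_noise (bool): Whether to generate gray noise, i.e. one noise map
            repeated to 3 channels. Default: False.

    Returns:
        (Numpy array): Generated noise, shape (h, w, c), float32. Add it to
            img to obtain the noisy image.
    """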
    if gray_noise:
        noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
        noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
    else:
        noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
    return noise
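Because `sigma` is expressed on the 8-bit scale, the returned noise has a standard deviation of about `sigma / 255` in the [0, 1] image domain, and the gray variant shares one noise map across channels. The sketch below reproduces that behavior with NumPy's `Generator` API, assuming a channel-last image; unlike the function above, it repeats gray noise to `c` channels instead of a hard-coded 3. `gaussian_noise` is a hypothetical name.

```python
import numpy as np


def gaussian_noise(img, sigma=10.0, gray_noise=False, rng=None):
    """Zero-mean Gaussian noise with std sigma/255 for an (h, w, c) image."""
    rng = np.random.default_rng() if rng is None else rng
    h, w, c = img.shape
    if gray_noise:
        # one noise map, copied to every channel
        noise = np.repeat(rng.standard_normal((h, w, 1)), c, axis=2)
    else:
        noise = rng.standard_normal((h, w, c))
    return (noise * sigma / 255.0).astype(np.float32)
```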


def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
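    """Add Gaussian noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): Noise standard deviation, measured on the [0, 255] scale. Default: 10.
        clip (bool): Whether to clip the output to [0, 1]. Default: True.
        rounds (bool): Whether to round the output to multiples of 1/255. Default: False.
        gray_noise (bool): Whether to add gray noise. Default: False.

    Returns:
        (Numpy array): Returned noisy image, shape (h, w, c), float32; range [0, 1]
            when clip is True.
    """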
    noise = generate_gaussian_noise(img, sigma, gray_noise)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out
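Every `add_*` helper in this file ends with the same clip/round branches. One way to factor them out is sketched below; `finalize_noisy` is a hypothetical helper, not part of BasicSR. Rounding snaps values to multiples of 1/255 to emulate an 8-bit image, and clipping keeps the result in [0, 1].

```python
import numpy as np


def finalize_noisy(out, clip=True, rounds=False):
    """Apply the optional 8-bit rounding and [0, 1] clipping to a noisy image."""
    if clip and rounds:
        return np.clip((out * 255.0).round(), 0, 255) / 255.0
    if clip:
        return np.clip(out, 0, 1)
    if rounds:
        return (out * 255.0).round() / 255.0
    return out
```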


def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
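    """Generate Gaussian noise (PyTorch version).

    Args:
        img (Tensor): Shape (b, c, h, w), range [0, 1], float32.
        sigma (float | Tensor): Noise standard deviation on the [0, 255] scale.
            Number or Tensor with shape (b). Default: 10.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): Generated noise, shape (b, c, h, w), float32.
    """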
    b, _, h, w = img.size()
    if not isinstance(sigma, (float, int)):
        sigma = sigma.view(img.size(0), 1, 1, 1)
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0

    if cal_gray_noise:
        # draw an independent single-channel noise map for every sample
        noise_gray = torch.randn(b, 1, h, w, dtype=img.dtype, device=img.device) * sigma / 255.

    # always calculate color noise
    noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.

    if cal_gray_noise:
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
    return noise
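The batched version mixes a color noise map and a single-channel gray map per sample, using `gray_noise` as a 0/1 mask of shape (b, 1, 1, 1) so that broadcasting selects one or the other. The NumPy analogue below illustrates that blending without PyTorch; `batch_gaussian_noise` is a hypothetical name, and it draws an independent gray map for each sample.

```python
import numpy as np


def batch_gaussian_noise(shape, sigma, gray, rng=None):
    """Per-sample Gaussian noise for a (b, c, h, w) batch; gray is a 0/1 mask of length b."""
    rng = np.random.default_rng() if rng is None else rng
    b, c, h, w = shape
    sigma = np.asarray(sigma, dtype=np.float32).reshape(b, 1, 1, 1)
    gray = np.asarray(gray, dtype=np.float32).reshape(b, 1, 1, 1)
    color = rng.standard_normal((b, c, h, w)).astype(np.float32) * sigma / 255.0
    mono = rng.standard_normal((b, 1, h, w)).astype(np.float32) * sigma / 255.0
    # mask 1 keeps the gray map (broadcast over channels), mask 0 keeps color noise
    return color * (1 - gray) + mono * gray
```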


def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
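    """Add Gaussian noise (PyTorch version).

    Args:
        img (Tensor): Shape (b, c, h, w), range [0, 1], float32.
        sigma (float | Tensor): Noise standard deviation on the [0, 255] scale.
            Number or Tensor with shape (b). Default: 10.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.
        clip (bool): Whether to clip the output to [0, 1]. Default: True.
        rounds (bool): Whether to round the output to multiples of 1/255. Default: False.

    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), float32; range [0, 1]
            when clip is True.
    """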
    noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


# ----------------------- Random Gaussian Noise ----------------------- #
def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
    if np.random.uniform() < gray_prob:
        gray_noise = True
    else:
        gray_noise = False
    return generate_gaussian_noise(img, sigma, gray_noise)


def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
    sigma = torch.rand(
        img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
    gray_noise = (gray_noise < gray_prob).float()
    return generate_gaussian_noise_pt(img, sigma, gray_noise)
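`random_generate_gaussian_noise_pt` draws one sigma per sample from a uniform range and turns a uniform draw into a Bernoulli gray/color flag. The same sampling, sketched in NumPy under the assumption that a `Generator` is passed in (`sample_noise_params` is a hypothetical name):

```python
import numpy as np


def sample_noise_params(batch, sigma_range=(0, 10), gray_prob=0.0, rng=None):
    """Per-sample sigma ~ U(sigma_range) and gray flag ~ Bernoulli(gray_prob)."""
    rng = np.random.default_rng() if rng is None else rng
    sigma = rng.uniform(sigma_range[0], sigma_range[1], size=batch)
    gray = (rng.random(batch) < gray_prob).astype(np.float32)
    return sigma, gray
```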


def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


# ----------------------- Poisson (Shot) Noise ----------------------- #


def generate_poisson_noise(img, scale=1.0, gray_noise=False):
    """Generate poisson noise.

    Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        scale (float): Noise scale. Default: 1.0.
        gray_noise (bool): Whether to generate gray noise. Default: False.

    Returns:
        (Numpy array): Generated noise (already multiplied by scale), shape
            (h, w, c), float32.
    """
    if gray_noise:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # round and clip image for counting vals correctly
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = len(np.unique(img))
    vals = 2**np.ceil(np.log2(vals))
    out = np.float32(np.random.poisson(img * vals) / float(vals))
    noise = out - img
    if gray_noise:
        noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
    return noise * scale
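The number of distinct intensity levels, rounded up to a power of two, acts as the photon-count scale: each pixel is replaced by Poisson(img * vals) / vals, so the noise is zero-mean with variance img / vals. A color-only NumPy sketch, assuming the `Generator` API (`poisson_levels` and `poisson_noise` are hypothetical names):

```python
import numpy as np


def poisson_levels(img):
    """Distinct 8-bit levels in img, rounded up to a power of two."""
    q = np.clip((img * 255.0).round(), 0, 255) / 255.0
    return 2.0 ** np.ceil(np.log2(len(np.unique(q))))


def poisson_noise(img, scale=1.0, rng=None):
    """Zero-mean shot noise for an image in [0, 1] (color variant only)."""
    rng = np.random.default_rng() if rng is None else rng
    q = np.clip((img * 255.0).round(), 0, 255) / 255.0
    vals = poisson_levels(q)
    out = rng.poisson(q * vals) / vals
    return ((out - q) * scale).astype(np.float32)
```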


def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
    """Add poisson noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        scale (float): Noise scale. Default: 1.0.
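        clip (bool): Whether to clip the output to [0, 1]. Default: True.
        rounds (bool): Whether to round the output to multiples of 1/255. Default: False.
        gray_noise (bool): Whether to generate gray noise. Default: False.

    Returns:
        (Numpy array): Returned noisy image, shape (h, w, c), float32; range [0, 1]
            when clip is True.
    """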
    noise = generate_poisson_noise(img, scale, gray_noise)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
    """Generate a batch of poisson noise (PyTorch version)

    Args:
        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
            Default: 1.0.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): Generated noise (already multiplied by scale), shape
            (b, c, h, w), float32.
    """
    b, _, h, w = img.size()
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0
    if cal_gray_noise:
        img_gray = rgb_to_grayscale(img, num_output_channels=1)
        # round and clip image for counting vals correctly
        img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
        # use for-loop to get the unique values for each sample
        vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
        vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
        vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
        out = torch.poisson(img_gray * vals) / vals
        noise_gray = out - img_gray
        noise_gray = noise_gray.expand(b, 3, h, w)

    # always calculate color noise
    # round and clip image for counting vals correctly
    img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
    # use for-loop to get the unique values for each sample
    vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
    vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
    vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
    out = torch.poisson(img * vals) / vals
    noise = out - img
    if cal_gray_noise:
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
    if not isinstance(scale, (float, int)):
        scale = scale.view(b, 1, 1, 1)
    return noise * scale
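In the batched version `vals` is computed per sample, so each image gets its own level count and the result is reshaped to (b, 1, 1, 1) for broadcasting. The same bookkeeping in NumPy (`batch_poisson_levels` is a hypothetical name):

```python
import numpy as np


def batch_poisson_levels(batch):
    """Per-sample level count (power of two) for a (b, c, h, w) batch."""
    q = np.clip((batch * 255.0).round(), 0, 255) / 255.0
    counts = [len(np.unique(q[i])) for i in range(q.shape[0])]
    return (2.0 ** np.ceil(np.log2(counts))).reshape(-1, 1, 1, 1)
```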


def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
    """Add poisson noise to a batch of images (PyTorch version).

    Args:
        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
            Default: 1.0.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
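            0 for False, 1 for True. Default: 0.
        clip (bool): Whether to clip the output to [0, 1]. Default: True.
        rounds (bool): Whether to round the output to multiples of 1/255. Default: False.

    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), float32; range [0, 1]
            when clip is True.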
    """
    noise = generate_poisson_noise_pt(img, scale, gray_noise)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


# ----------------------- Random Poisson (Shot) Noise ----------------------- #


def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
    scale = np.random.uniform(scale_range[0], scale_range[1])
    if np.random.uniform() < gray_prob:
        gray_noise = True
    else:
        gray_noise = False
    return generate_poisson_noise(img, scale, gray_noise)


def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_poisson_noise(img, scale_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
    scale = torch.rand(
        img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
    gray_noise = (gray_noise < gray_prob).float()
    return generate_poisson_noise_pt(img, scale, gray_noise)


def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


# ------------------------------------------------------------------------ #
# --------------------------- JPEG compression --------------------------- #
# ------------------------------------------------------------------------ #


def add_jpg_compression(img, quality=90):
    """Add JPG compression artifacts.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality (int | float): JPG compression quality. 0 for lowest quality,
            100 for best quality. Non-integer values are truncated to int
            because OpenCV encoder parameters must be integers. Default: 90.

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c), range [0, 1],
            float32.
    """
    img = np.clip(img, 0, 1)
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)]
    _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
    img = np.float32(cv2.imdecode(encimg, 1)) / 255.
    return img


def random_add_jpg_compression(img, quality_range=(90, 100)):
    """Randomly add JPG compression artifacts.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality_range (tuple[float] | list[float]): JPG compression quality
            range. 0 for lowest quality, 100 for best quality.
            Default: (90, 100).

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
            float32.
    """
    quality = np.random.uniform(quality_range[0], quality_range[1])
    return add_jpg_compression(img, quality)


================================================
FILE: basicsr/data/gaussian_kernels.py
================================================
import math
import numpy as np
import random
from scipy.ndimage import shift
from scipy.stats import multivariate_normal


def sigma_matrix2(sig_x, sig_y, theta):
    """Calculate the rotated sigma matrix (two dimensional matrix).
    Args:
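        sig_x (float): Standard deviation along the x axis before rotation.
        sig_y (float): Standard deviation along the y axis before rotation.
        theta (float): Rotation angle in radians.
    Returns:
        ndarray: Rotated covariance matrix U diag(sig_x^2, sig_y^2) U^T, shape (2, 2).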
    """
    D = np.array([[sig_x**2, 0], [0, sig_y**2]])
    U = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    return np.dot(U, np.dot(D, U.T))
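The matrix U diag(sig_x^2, sig_y^2) U^T is a covariance whose eigenvalues are the squared sigmas and whose leading eigenvector points along the rotated x axis. Recovering the parameters with `np.linalg.eigh` is a convenient sanity check; `rotated_cov` and `cov_to_params` are hypothetical helpers, and the recovered angle is folded into (-pi/2, pi/2] because an axis direction is only defined up to sign.

```python
import numpy as np


def rotated_cov(sig_x, sig_y, theta):
    """Same construction as sigma_matrix2: U @ diag(sig_x^2, sig_y^2) @ U.T."""
    u = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    return u @ np.diag([sig_x**2, sig_y**2]) @ u.T


def cov_to_params(cov):
    """Return (larger sigma, smaller sigma, major-axis angle in (-pi/2, pi/2])."""
    evals, evecs = np.linalg.eigh(cov)  # eigenvalues in ascending order
    major = evecs[:, 1]
    theta = np.arctan2(major[1], major[0])
    # fold the angle: v and -v describe the same axis
    if theta <= -np.pi / 2:
        theta += np.pi
    elif theta > np.pi / 2:
        theta -= np.pi
    return np.sqrt(evals[1]), np.sqrt(evals[0]), theta
```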


def mesh_grid(kernel_size):
    """Generate the mesh grid, centering at zero.
    Args:
        kernel_size (int): Kernel size; must be odd for the grid to be centered at zero.
    Returns:
        xy (ndarray): with the shape (kernel_size, kernel_size, 2)
        xx (ndarray): with the shape (kernel_size, kernel_size)
        yy (ndarray): with the shape (kernel_size, kernel_size)
    """
    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)
    xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)),
                    yy.reshape(kernel_size * kernel_size,
                               1))).reshape(kernel_size, kernel_size, 2)
    return xy, xx, yy
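The `np.hstack`/`reshape` construction pairs each x coordinate with its y coordinate, which is equivalent to stacking `xx` and `yy` on a new last axis. The sketch below (hypothetical name `mesh_grid_stack`) makes the centering explicit; its axis agrees with `np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)` for odd sizes only, which is why callers assert an odd kernel size.

```python
import numpy as np


def mesh_grid_stack(kernel_size):
    """Zero-centered grid where xy[i, j] == (xx[i, j], yy[i, j])."""
    ax = np.arange(kernel_size) - (kernel_size - 1) / 2
    xx, yy = np.meshgrid(ax, ax)
    return np.stack([xx, yy], axis=-1), xx, yy
```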


def pdf2(sigma_matrix, grid):
    """Calculate PDF of the bivariate Gaussian distribution.
    Args:
        sigma_matrix (ndarray): with the shape (2, 2)
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.
    Returns:
        kernel (ndarray): Un-normalized kernel with shape (K, K).
    """
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
    return kernel
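`np.sum(np.dot(grid, inverse_sigma) * grid, 2)` evaluates the quadratic form [x, y] Sigma^-1 [x, y]^T at every grid point. An equivalent `np.einsum` spelling, checked against the closed form for an isotropic Sigma = sigma^2 I, is sketched below; `quad_form` is a hypothetical name.

```python
import numpy as np


def quad_form(grid, sigma_matrix):
    """[x, y] @ inv(sigma_matrix) @ [x, y]^T for every point of a (K, K, 2) grid."""
    inv = np.linalg.inv(sigma_matrix)
    return np.einsum('...i,ij,...j->...', grid, inv, grid)
```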


def cdf2(D, grid):
    """Calculate the CDF of the standard bivariate Gaussian distribution.
        Used in skewed Gaussian distribution.
    Args:
        D (ndarray | list): Skew matrix with shape (2, 2).
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.
    Returns:
        cdf (ndarray): skewed cdf.
    """
    rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
    grid = np.dot(grid, D)
    cdf = rv.cdf(grid)
    return cdf


def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=None):
    """Generate a bivariate skew Gaussian kernel.
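        Described in `A multivariate skew normal distribution`_ by Shi et al. (2004).
    Args:
        kernel_size (int): Kernel size.
        sig_x (float): Standard deviation along the x axis before rotation.
        sig_y (float): Standard deviation along the y axis before rotation.
        theta (float): Rotation angle in radians.
        D (ndarray | list): Skew matrix with shape (2, 2).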
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
    Returns:
        kernel (ndarray): normalized kernel.
    .. _A multivariate skew normal distribution:
        https://www.sciencedirect.com/science/article/pii/S0047259X03001313
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    pdf = pdf2(sigma_matrix, grid)
    cdf = cdf2(D, grid)
    kernel = pdf * cdf
    kernel = kernel / np.sum(kernel)
    return kernel
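With the identity covariance used in `cdf2`, the standard bivariate normal CDF factorizes into a product of two univariate CDFs, Phi(u) * Phi(v). The sketch below uses that fact to build the same skewed kernel (pdf times skewed cdf, then normalized) with only NumPy and `math.erf`, as an alternative to `scipy.stats.multivariate_normal`; `std_normal_cdf2` and `skew_gaussian_kernel` are hypothetical names. With D = 0 the CDF term is the constant 1/4 and the kernel reduces to the plain Gaussian.

```python
import math

import numpy as np

_erf = np.vectorize(math.erf)


def std_normal_cdf2(points):
    """CDF of N(0, I_2) at each (..., 2) point; it factorizes for identity covariance."""
    phi = 0.5 * (1.0 + _erf(points / math.sqrt(2.0)))
    return phi[..., 0] * phi[..., 1]


def skew_gaussian_kernel(kernel_size, sig_x, sig_y, theta, D):
    """Normalized pdf(x) * cdf(x @ D) kernel on a zero-centered grid."""
    ax = np.arange(kernel_size) - (kernel_size - 1) / 2
    xx, yy = np.meshgrid(ax, ax)
    grid = np.stack([xx, yy], axis=-1)
    u = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta), np.cos(theta)]])
    inv = np.linalg.inv(u @ np.diag([sig_x**2, sig_y**2]) @ u.T)
    pdf = np.exp(-0.5 * np.einsum('...i,ij,...j->...', grid, inv, grid))
    cdf = std_normal_cdf2(grid @ np.asarray(D, dtype=float))
    kernel = pdf * cdf
    return kernel / kernel.sum()
```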


def mass_center_shift(kernel_size, kernel):
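    """Calculate the shift of the mass center of a kernel.
    Args:
        kernel_size (int): Kernel size (odd).
        kernel (ndarray): normalized kernel.
    Returns:
        delta_h (float): Vertical (row) offset of the mass center from the kernel center.
        delta_w (float): Horizontal (column) offset of the mass center from the kernel center.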
    """
    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    col_sum, row_sum = np.sum(kernel, axis=0), np.sum(kernel, axis=1)
    delta_h = np.dot(row_sum, ax)
    delta_w = np.dot(col_sum, ax)
    return delta_h, delta_w
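The shift is the expected row and column coordinate under the kernel, measured from the center; `bivariate_skew_Gaussian_center` then moves the kernel by (-delta_h, -delta_w). A compact equivalent that infers the size from the array is sketched below (hypothetical name `mass_center_offset`).

```python
import numpy as np


def mass_center_offset(kernel):
    """(row, column) offset of a normalized square kernel's mass center."""
    k = kernel.shape[0]
    ax = np.arange(k) - (k - 1) / 2
    return float(kernel.sum(axis=1) @ ax), float(kernel.sum(axis=0) @ ax)
```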


def bivariate_skew_Gaussian_center(kernel_size,
                                   sig_x,
                                   sig_y,
                                   theta,
                                   D,
                                   grid=None):
    """Generate a bivariate skew Gaussian kernel at center. Shift with nearest padding.
    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        D (ndarray | list): Skew matrix with shape (2, 2).
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
    Returns:
        kernel (ndarray): centered and normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    kernel = bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid)
    delta_h, delta_w = mass_center_shift(kernel_size, kernel)
    kernel = shift(kernel, [-delta_h, -delta_w], mode='nearest')
    kernel = kernel / np.sum(kernel)
    return kernel


def bivariate_anisotropic_Gaussian(kernel_size,
                                   sig_x,
                                   sig_y,
                                   theta,
                                   grid=None):
    """Generate a bivariate anisotropic Gaussian kernel.
    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    kernel = pdf2(sigma_matrix, grid)
    kernel = kernel / np.sum(kernel)
    return kernel


def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None):
    """Generate a bivariate isotropic Gaussian kernel.
    Args:
        kernel_size (int):
        sig (float):
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
    kernel = pdf2(sigma_matrix, grid)
    kernel = kernel / np.sum(kernel)
    return kernel


def bivariate_generalized_Gaussian(kernel_size,
                                   sig_x,
                                   sig_y,
                                   theta,
                                   beta,
                                   grid=None):
    """Generate a bivariate generalized Gaussian kernel.
        Described in `Parameter Estimation For Multivariate Generalized Gaussian Distributions`_
        by Pascal et al. (2013).
    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        beta (float): shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
    Returns:
        kernel (ndarray): normalized kernel.
    .. _Parameter Estimation For Multivariate Generalized Gaussian Distributions:
        https://arxiv.org/abs/1302.6498
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(
        -0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
    kernel = kernel / np.sum(kernel)
    return kernel
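For an isotropic covariance Sigma = sigma^2 I the quadratic form is q = (x^2 + y^2) / sigma^2, and the kernel is exp(-0.5 * q^beta) before normalization: beta = 1 recovers the Gaussian, beta > 1 flattens the top, and beta < 1 sharpens the peak. A sketch with the hypothetical name `generalized_gaussian_iso`:

```python
import numpy as np


def generalized_gaussian_iso(kernel_size, sigma, beta):
    """Normalized exp(-0.5 * q^beta) kernel with q = (x^2 + y^2) / sigma^2."""
    ax = np.arange(kernel_size) - (kernel_size - 1) / 2
    xx, yy = np.meshgrid(ax, ax)
    q = (xx**2 + yy**2) / sigma**2  # quadratic form for Sigma = sigma^2 * I
    kernel = np.exp(-0.5 * np.power(q, beta))
    return kernel / kernel.sum()
```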


def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid=None):
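    """Generate a plateau-like anisotropic kernel.
    1 / (1 + q^beta), where q = [x, y] Sigma^-1 [x, y]^T.
    Args:
        kernel_size (int): Kernel size.
        sig_x (float): Standard deviation along the x axis before rotation.
        sig_y (float): Standard deviation along the y axis before rotation.
        theta (float): Radian measurement.
        beta (float): Shape parameter; larger values give a flatter top and a
            sharper fall-off. beta = 1 gives 1 / (1 + q), not a Gaussian.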
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.reciprocal(
        np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
    kernel = kernel / np.sum(kernel)
    return kernel


def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None):
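    """Generate a plateau-like isotropic kernel.
    1 / (1 + q^beta), where q = (x^2 + y^2) / sig^2.
    Args:
        kernel_size (int): Kernel size.
        sig (float): Scale parameter.
        beta (float): Shape parameter; larger values give a flatter top and a
            sharper fall-off. beta = 1 gives 1 / (1 + q), not a Gaussian.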
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.reciprocal(
        np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
    kernel = kernel / np.sum(kernel)
    return kernel
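Both plateau kernels evaluate 1 / (1 + q^beta) on the same quadratic form q. Compared with a Gaussian of the same scale, the profile decays polynomially rather than exponentially, so its tails are much heavier, and a larger beta makes the top flatter. The isotropic case is sketched below with a hypothetical `plateau_iso` helper.

```python
import numpy as np


def plateau_iso(kernel_size, sigma, beta):
    """Normalized 1 / (1 + q^beta) kernel with q = (x^2 + y^2) / sigma^2."""
    ax = np.arange(kernel_size) - (kernel_size - 1) / 2
    xx, yy = np.meshgrid(ax, ax)
    q = (xx**2 + yy**2) / sigma**2
    kernel = 1.0 / (np.power(q, beta) + 1.0)
    return kernel / kernel.sum()
```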


def random_bivariate_skew_Gaussian_center(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          noise_range=None,
                                          strict=False):
    """Randomly generate bivariate skew Gaussian kernels at center.
    Args:
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
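        rotation_range (tuple): [-math.pi, math.pi]
        noise_range (tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
        strict (bool): If True, swap the sampled sigmas so that sigma_x >= sigma_y
            and also return the sampled parameters. Default: False.
    Returns:
        kernel (ndarray): Normalized kernel, or the tuple
            (kernel, sigma_x, sigma_y, rotation, D) if strict is True.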
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        sigma_max = np.max([sigma_x, sigma_y])
        sigma_min = np.min([sigma_x, sigma_y])
        sigma_x, sigma_y = sigma_max, sigma_min
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    sigma_max = np.max([sigma_x, sigma_y])
    thres = 3 / sigma_max
    D = [[np.random.uniform(-thres, thres),
          np.random.uniform(-thres, thres)],
         [np.random.uniform(-thres, thres),
          np.random.uniform(-thres, thres)]]

    kernel = bivariate_skew_Gaussian_center(kernel_size, sigma_x, sigma_y,
                                            rotation, D)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    if strict:
        return kernel, sigma_x, sigma_y, rotation, D
    else:
        return kernel


def random_bivariate_anisotropic_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          noise_range=None,
                                          strict=False):
    """Randomly generate bivariate anisotropic Gaussian kernels.
    Args:
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
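        rotation_range (tuple): [-math.pi, math.pi]
        noise_range (tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
        strict (bool): If True, swap the sampled sigmas so that sigma_x >= sigma_y
            and also return the sampled parameters. Default: False.
    Returns:
        kernel (ndarray): Normalized kernel, or the tuple
            (kernel, sigma_x, sigma_y, rotation) if strict is True.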
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        sigma_max = np.max([sigma_x, sigma_y])
        sigma_min = np.min([sigma_x, sigma_y])
        sigma_x, sigma_y = sigma_max, sigma_min
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    kernel = bivariate_anisotropic_Gaussian(kernel_size, sigma_x, sigma_y,
                                            rotation)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    if strict:
        return kernel, sigma_x, sigma_y, rotation
    else:
        return kernel
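Each random generator can multiply the kernel by element-wise uniform noise and renormalize it, which breaks the perfect symmetry of the analytic kernels while keeping the kernel summing to one. Because the input kernel already sums to one, every entry changes by a factor between low/high and high/low. A sketch with the hypothetical name `perturb_kernel`:

```python
import numpy as np


def perturb_kernel(kernel, noise_range=(0.75, 1.25), rng=None):
    """Multiply a normalized kernel by i.i.d. uniform noise, then renormalize."""
    rng = np.random.default_rng() if rng is None else rng
    low, high = noise_range
    assert low < high, 'Wrong noise range.'
    noisy = kernel * rng.uniform(low, high, size=kernel.shape)
    return noisy / noisy.sum()
```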


def random_bivariate_isotropic_Gaussian(kernel_size,
                                        sigma_range,
                                        noise_range=None,
                                        strict=False):
    """Randomly generate bivariate isotropic Gaussian kernels.
    Args:
        kernel_size (int):
        sigma_range (tuple): [0.6, 5]
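        noise_range (tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
        strict (bool): If True, also return the sampled sigma. Default: False.
    Returns:
        kernel (ndarray): Normalized kernel, or the tuple (kernel, sigma)
            if strict is True.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_range[0] < sigma_range[1], 'Wrong sigma_range.'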
    sigma = np.random.uniform(sigma_range[0], sigma_range[1])

    kernel = bivariate_isotropic_Gaussian(kernel_size, sigma)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    if strict:
        return kernel, sigma
    else:
        return kernel


def random_bivariate_generalized_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          beta_range,
                                          noise_range=None,
                                          strict=False):
    """Randomly generate bivariate generalized Gaussian kernels.
    Args:
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
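        rotation_range (tuple): [-math.pi, math.pi]
        beta_range (tuple): [0.5, 8]. beta is drawn from [beta_range[0], 1] or
            [1, beta_range[1]] with equal probability.
        noise_range (tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
        strict (bool): If True, swap the sampled sigmas so that sigma_x >= sigma_y
            and also return the sampled parameters. Default: False.
    Returns:
        kernel (ndarray): Normalized kernel, or the tuple
            (kernel, sigma_x, sigma_y, rotation, beta) if strict is True.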
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        sigma_max = np.max([sigma_x, sigma_y])
        sigma_min = np.min([sigma_x, sigma_y])
        sigma_x, sigma_y = sigma_max, sigma_min
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y,
                                            rotation, beta)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    if strict:
        return kernel, sigma_x, sigma_y, rotation, beta
    else:
        return kernel
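`random_bivariate_generalized_Gaussian` and `random_bivariate_plateau_type1` do not sample beta uniformly over `beta_range`; they first pick the side of 1 with probability 1/2. With `beta_range = (0.5, 8)`, a single uniform draw would fall below 1 only 0.5 / 7.5, about 6.7%, of the time, whereas this scheme produces shapes below and above beta = 1 equally often. A sketch using a NumPy `Generator` (hypothetical name `sample_beta`):

```python
import numpy as np


def sample_beta(beta_range, rng):
    """Draw beta below or above 1 with equal probability."""
    if rng.uniform() < 0.5:
        return rng.uniform(beta_range[0], 1)
    return rng.uniform(1, beta_range[1])
```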


def random_bivariate_plateau_type1(kernel_size,
                                   sigma_x_range,
                                   sigma_y_range,
                                   rotation_range,
                                   beta_range,
                                   noise_range=None,
                                   strict=False):
    """Randomly generate bivariate plateau type1 kernels.
    Args:
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
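        rotation_range (tuple): [-math.pi/2, math.pi/2]
        beta_range (tuple): [1, 4]. beta is drawn from [beta_range[0], 1] or
            [1, beta_range[1]] with equal probability.
        noise_range (tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
        strict (bool): If True, swap the sampled sigmas so that sigma_x >= sigma_y
            and also return the sampled parameters. Default: False.
    Returns:
        kernel (ndarray): Normalized kernel, or the tuple
            (kernel, sigma_x, sigma_y, rotation, beta) if strict is True.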
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        sigma_max = np.max([sigma_x, sigma_y])
        sigma_min = np.min([sigma_x, sigma_y])
        sigma_x, sigma_y = sigma_max, sigma_min
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_plateau_type1(kernel_size, sigma_x, sigma_y, rotation,
                                     beta)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    if strict:
        return kernel, sigma_x, sigma_y, rotation, beta
    else:
        return kernel
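`random_bivariate_plateau_type1` samples its parameters in a particular way: in strict mode the two sigmas are sorted so `sigma_x` is the larger one, and `beta` is drawn from `[beta_range[0], 1]` or `[1, beta_range[1]]` with equal probability, so values below and above 1 are equally likely however wide each side is. A pure-Python sketch of just the sampling, assuming a `random.Random`-like `rng`; `sample_plateau_params` is a hypothetical name and no kernel is built.

```python
import math
import random


def sample_plateau_params(sigma_x_range, sigma_y_range, rotation_range,
                          beta_range, strict=False, rng=random):
    """Draw (sigma_x, sigma_y, rotation, beta) as random_bivariate_plateau_type1 does."""
    sigma_x = rng.uniform(*sigma_x_range)
    sigma_y = rng.uniform(*sigma_y_range)
    if strict:
        # Strict mode sorts the axes so that sigma_x is always the larger one.
        sigma_x, sigma_y = max(sigma_x, sigma_y), min(sigma_x, sigma_y)
    rotation = rng.uniform(*rotation_range)
    # A fair coin decides whether beta is drawn below or above 1.
    if rng.uniform(0, 1) < 0.5:
        beta = rng.uniform(beta_range[0], 1)
    else:
        beta = rng.uniform(1, beta_range[1])
    return sigma_x, sigma_y, rotation, beta


rng = random.Random(0)
draws = [sample_plateau_params((0.6, 5), (0.6, 5), (-math.pi / 2, math.pi / 2),
                               (0.5, 4), strict=True, rng=rng)
         for _ in range(1000)]
below = sum(beta < 1 for *_, beta in draws)
print(below, len(draws) - below)  # the two counts should be roughly balanced
```

The split only behaves as described when `beta_range[0] < 1 < beta_range[1]`. With the docstring's example `[1, 4]`, the first branch always returns 1, so about half of the kernels get `beta == 1` exactly.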


def random_bivariate_plateau_type1_iso(kernel_size,
                                       sigma_range,
                                       beta_range,
                                       noise_range=None,
                                       strict=False):
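    """Randomly generate bivariate plateau type1 kernels (iso).
    Args:
        kernel_size (int): Kernel side length; must be odd.
        sigma_range (tuple): Range of sigma, e.g. [0.6, 5].
        beta_range (tuple): Range of the shape parameter beta, e.g. [1, 4].
        noise_range (tuple, optional): Multiplicative kernel noise, e.g. [0.75, 1.25]. Default: None
        strict (bool): If True, also return the sampled parameters. Default: False
    Returns:
        ndarray: The normalized kernel, or (kernel, sigma, beta) if strict is True.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_range[0] < sigma_range[1], 'Wrong sigma_range.'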
    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
    beta = np.random.uniform(beta_range[0], beta_range[1])

    kernel = bivariate_plateau_type1_iso(kernel_size, sigma, beta)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    if strict:
        return kernel, sigma, beta
    else:
        return kernel


def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size=21,
                         sigma_x_range=[0.6, 5],
                         sigma_y_range=[0.6, 5],
                         rotation_range=[-math.pi, math.pi],
                         beta_range=[0.5, 8],
                         noise_range=None):
    """Randomly generate mixed kernels.
    Args:
        kernel_list (tuple): Names of the kernel types to sample from. Supported:
            ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', 'plateau_aniso']
        kernel_prob (tuple): Sampling probability (weight) for each kernel type.
        kernel_size (int): Kernel side length; must be odd.
        sigma_x_range (tuple): [0.6, 5]. Also used as the sigma range of the
            isotropic types ('iso', 'plateau_iso').
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        beta_range (tuple): [0.5, 8]
        noise_range (tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
    Returns:
        kernel (ndarray): The normalized kernel.
    """
    kernel_type = random.choices(kernel_list, kernel_prob)[0]
    if kernel_type == 'iso':
        kernel = random_bivariate_isotropic_Gaussian(
            kernel_size, sigma_x_range, noise_range=noise_range)
    elif kernel_type == 'aniso':
        kernel = random_bivariate_anisotropic_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            noise_range=noise_range)
    elif kernel_type == 'skew':
        kernel = random_bivariate_skew_Gaussian_center(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            noise_range=noise_range)
    elif kernel_type == 'generalized':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            beta_range,
            noise_range=noise_range)
    elif kernel_type == 'plateau_iso':
        kernel = random_bivariate_plateau_type1_iso(
            kernel_size, sigma_x_range, beta_range, noise_range=noise_range)
    elif kernel_type == 'plateau_aniso':
        kernel = random_bivariate_plateau_type1(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            beta_range,
            noise_range=noise_range)
    else:
        raise ValueError(f'Unsupported kernel type: {kernel_type}.')
    # Each generator above already applies the multiplicative noise (when
    # noise_range is given) and renormalizes, so no second pass is needed.
    return kernel
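`random_mixed_kernels` picks one kernel type with `random.choices` (the weights are relative and need not sum to 1) and then dispatches to the matching generator. Below is a sketch of the same choice with a dictionary dispatch swapped in for the `if`/`elif` chain, so an unsupported name fails with a clear `ValueError`. `pick_kernel_type`, `build_kernel` and the `builders` mapping are hypothetical.

```python
import random


def pick_kernel_type(kernel_list, kernel_prob, rng=random):
    """Weighted choice of one kernel type, as in random_mixed_kernels."""
    if len(kernel_list) != len(kernel_prob):
        raise ValueError('kernel_list and kernel_prob must have the same length.')
    # choices() returns a list of k=1 items; weights are relative, not required to sum to 1.
    return rng.choices(kernel_list, kernel_prob)[0]


def build_kernel(kernel_type, builders):
    """Look up the builder for kernel_type; unknown names raise ValueError."""
    try:
        builder = builders[kernel_type]
    except KeyError:
        raise ValueError(f'Unsupported kernel type: {kernel_type}') from None
    return builder()


# Hypothetical builders standing in for the random_* generators above.
builders = {'iso': lambda: 'iso kernel', 'aniso': lambda: 'aniso kernel'}
kernel_type = pick_kernel_type(['iso', 'aniso'], [0.5, 0.5])
print(build_kernel(kernel_type, builders))
```

A dictionary keeps the list of supported names in one place, which also makes it easy to validate a config's `kernel_list` up front.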


def show_one_kernel():
    import matplotlib.pyplot as plt
    kernel_size = 21

    # Each call below overwrites `kernel`; only the last one (generalized
    # Gaussian) is visualized. Reorder the calls to inspect another type.
    # bivariate skew Gaussian
    D = [[3 / 4, 0], [0, 0.5]]
    kernel = bivariate_skew_Gaussian_center(kernel_size, 2, 4, -math.pi / 4, D)
    # bivariate anisotropic Gaussian
    kernel = bivariate_anisotropic_Gaussian(kernel_size, 2, 4, -math.pi / 4)
    # bivariate isotropic Gaussian
    kernel = bivariate_isotropic_Gaussian(kernel_size, 1)
    # bivariate generalized Gaussian
    kernel = bivariate_generalized_Gaussian(
        kernel_size, 2, 4, -math.pi / 4, beta=4)

    delta_h, delta_w = mass_center_shift(kernel_size, kernel)
    print(delta_h, delta_w)

    fig, axs = plt.subplots(nrows=2, ncols=2)
    # axs.set_axis_off()
    ax = axs[0][0]
    im = ax.matshow(kernel, cmap='jet', origin='upper')
    fig.colorbar(im, ax=ax)

    # image
    ax = axs[0][1]
    kernel_vis = kernel - np.min(kernel)
    kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
    ax.imshow(kernel_vis, interpolation='nearest')

    _, xx, yy = mesh_grid(kernel_size)
    # contour
    ax = axs[1][0]
    CS = ax.contour(xx, yy, kernel, origin='upper')
    ax.clabel(CS, inline=1, fontsize=3)

    # contourf
    ax = axs[1][1]
    kernel = kernel / np.max(kernel)
    p = ax.contourf(
        xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
    fig.colorbar(p)

    plt.show()


def show_plateau_kernel():
    import matplotlib.pyplot as plt
    kernel_size = 21

    kernel = plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
    kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5)
    kernel_gau = bivariate_generalized_Gaussian(
        kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
    delta_h, delta_w = mass_center_shift(kernel_size, kernel)
    print(delta_h, delta_w)

    # kernel_slice = kernel[10, :]
    # kernel_gau_slice = kernel_gau[10, :]
    # kernel_norm_slice = kernel_norm[10, :]
    # fig, ax = plt.subplots()
    # t = list(range(1, 22))

    # ax.plot(t, kernel_gau_slice)
    # ax.plot(t, kernel_slice)
    # ax.plot(t, kernel_norm_slice)

    # t = np.arange(0, 10, 0.1)
    # y = np.exp(-0.5 * t)
    # y2 = np.reciprocal(1 + t)
    # print(t.shape)
    # print(y.shape)
    # ax.plot(t, y)
    # ax.plot(t, y2)
    # plt.show()

    fig, axs = plt.subplots(nrows=2, ncols=2)
    # axs.set_axis_off()
    ax = axs[0][0]
    im = ax.matshow(kernel, cmap='jet', origin='upper')
    fig.colorbar(im, ax=ax)

    # image
    ax = axs[0][1]
    kernel_vis = kernel - np.min(kernel)
    kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
    ax.imshow(kernel_vis, interpolation='nearest')

    _, xx, yy = mesh_grid(kernel_size)
    # contour
    ax = axs[1][0]
    CS = ax.contour(xx, yy, kernel, origin='upper')
    ax.clabel(CS, inline=1, fontsize=3)

    # contourf
    ax = axs[1][1]
    kernel = kernel / np.max(kernel)
    p = ax.contourf(
        xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
    fig.colorbar(p)

    plt.show()


================================================
FILE: basicsr/data/inpainting_dataset.py
================================================
import os
import random
from pathlib import Path

from PIL import Image
import cv2
import ffmpeg
import io
import av
import numpy as np
import torch
from torchvision.transforms.functional import normalize
from basicsr.data.degradations import (random_add_gaussian_noise,
                                       random_mixed_kernels)
from basicsr.data.data_util import paths_from_folder, brush_stroke_mask, brush_stroke_mask_video, random_ff_mask
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, img2tensor, imfrombytes, scandir
from basicsr.utils.registry import DATASET_REGISTRY
from facelib.utils.face_restoration_helper import FaceAligner
from torch.utils import data as data

@DATASET_REGISTRY.register()
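class InpaintingDataset(data.Dataset):
    """Video face inpainting dataset.

    Samples `video_length` consecutive GT frames from a clip listed in
    `global_meta_info_file` and builds the input frames by painting random
    brush-stroke masks over them (`brush_stroke_mask_video`). The blur, noise,
    resize, CRF and codec options are read and logged but are not applied
    in `__getitem__`.

    Returns a dict with 'in' (masked frames) and 'gt' (clean frames), both of
    shape (T, C, H, W), plus the sample 'key'.
    """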
    def __init__(self, opt):
        super(InpaintingDataset, self).__init__()
        self.opt = opt
        self.gt_root = Path(opt['dataroot_gt'])

        self.num_frame = opt['video_length'] # 5
        self.scale = opt['scale'] # [1, 4]
        self.need_align = opt.get('need_align', False) # False
        self.normalize = opt.get('normalize', False) # True

        self.keys = []
        with open(opt['global_meta_info_file'], 'r') as fin:
            for line in fin:
                real_clip_path = '/'.join(line.split('/')[:-1])
                clip_length = int(line.split('/')[-1])
                self.keys.extend([f'{real_clip_path}/{clip_length:08d}/{0:08d}'])

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            self.io_backend_opt['db_paths'] = [self.gt_root]
            self.io_backend_opt['client_keys'] = ['gt']

        # temporal augmentation configs
        self.interval_list = opt['interval_list'] # [1]
        self.random_reverse = opt['random_reverse']
        interval_str = ','.join(str(x) for x in opt['interval_list']) # '1'
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

        # degradations
        # blur
        self.blur_kernel_size = opt['blur_kernel_size'] # 21
        self.kernel_list = opt['kernel_list']           # ['iso', 'aniso']
        self.kernel_prob = opt['kernel_prob']           # [0.5, 0.5]  
        self.blur_x_sigma = opt['blur_x_sigma']         # [0.2, 3]
        self.blur_y_sigma = opt['blur_y_sigma']         # [0.2, 3]
        # noise
        self.noise_range = opt['noise_range']           # [0, 25] 
        # resize
        self.resize_prob = opt['resize_prob']           # [0.25, 0.25, 0.5]
        # crf
        self.crf_range = opt['crf_range']               # [10, 30]
        # codec
        self.vcodec = opt['vcodec']                     # ['libx264']
        self.vcodec_prob = opt['vcodec_prob']           # [1]

        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
                    f'x_sigma: [{", ".join(map(str, self.blur_x_sigma))}], '
                    f'y_sigma: [{", ".join(map(str, self.blur_y_sigma))}], ')
        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
        logger.info(f'CRF compression: [{", ".join(map(str, self.crf_range))}]')
        logger.info(f'Codec: [{", ".join(map(str, self.vcodec))}]')

        if self.need_align:
            self.dataroot_meta_info = opt['dataroot_meta_info']
            self.face_aligner = FaceAligner(
                upscale_factor=1,
                face_size=512,
                crop_ratio=(1, 1),
                det_model='retinaface_resnet50',
                save_ext='png',
                use_parse=True)
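Each line of `global_meta_info_file` is parsed as `<clip path>/<clip length>`, and `__init__` turns it into a key `<clip path>/<length:08d>/00000000` that `__getitem__` splits apart again. A round-trip sketch in plain Python; `make_key` and `parse_key` are hypothetical names, and the sample meta line is made up to match the key example in the `VFHQRealDegradationDatasetNew` docstring.

```python
def make_key(meta_line):
    """'<clip path>/<clip length>' -> '<clip path>/<length:08d>/00000000'."""
    parts = meta_line.strip().split('/')
    clip_path = '/'.join(parts[:-1])
    clip_length = int(parts[-1])
    return f'{clip_path}/{clip_length:08d}/{0:08d}'


def parse_key(key):
    """Inverse used in __getitem__: (clip_path, clip_length, frame_idx, clip_name)."""
    parts = key.split('/')
    clip_path = '/'.join(parts[:-2])
    return clip_path, int(parts[-2]), int(parts[-1]), clip_path.split('/')[-1]


key = make_key('id00020#t0bbIRgKKzM#00381.txt#000.mp4/152\n')
print(key)  # id00020#t0bbIRgKKzM#00381.txt#000.mp4/00000152/00000000
```

Because every key ends in frame index 0, the window position is redrawn at random in `__getitem__` whenever `video_length >= 2`.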

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        key = self.keys[index]
        real_clip_path = '/'.join(key.split('/')[:-2])
        clip_length = int(key.split('/')[-2])
        frame_idx = int(key.split('/')[-1])
        clip_name = real_clip_path.split('/')[-1]

        if os.path.exists(os.path.join(self.gt_root, "train", clip_name)):
            paths = sorted(list(scandir(os.path.join(self.gt_root, "train", clip_name))))
        elif os.path.exists(os.path.join(self.gt_root, "test", clip_name)):
            paths = sorted(list(scandir(os.path.join(self.gt_root, "test", clip_name))))
        else:
            paths = sorted(list(scandir(os.path.join(self.gt_root, clip_name))))

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

        # the window does not fit in the clip: re-select the interval
        # (this never terminates if no interval in interval_list fits)
        while (clip_length - self.num_frame * interval) < 0:
            interval = random.choice(self.interval_list)

        # ensure not exceeding the borders
        start_frame_idx = frame_idx - self.num_frame // 2 * interval
        end_frame_idx = frame_idx + (self.num_frame + 1) // 2 * interval

        while (start_frame_idx < 0) or (end_frame_idx > clip_length):
            # this upper bound keeps end_frame_idx <= clip_length, so one draw suffices
            frame_idx = random.randint(self.num_frame // 2 * interval,
                                       clip_length - (self.num_frame + 1) // 2 * interval)
            start_frame_idx = frame_idx - self.num_frame // 2 * interval
            end_frame_idx = frame_idx + (self.num_frame + 1) // 2 * interval
        neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == self.num_frame, (
            f'Wrong length of neighbor list: {len(neighbor_list)}')

        # get the neighboring GT frames
        img_gts = []

        need_align = False
        if self.need_align:
            clip_info_path = os.path.join(self.dataroot_meta_info, f'{clip_name}.txt')
            if os.path.exists(clip_info_path):
                need_align = True
                clip_info = []
                with open(clip_info_path, 'r', encoding='utf-8') as fin:
                    for line in fin:
                        line = line.strip()
                        clip_info.append(line)

        for neighbor in neighbor_list:
            img_gt_path = os.path.join(self.gt_root, clip_name, paths[neighbor])
            if not os.path.exists(img_gt_path):
                img_gt_path = os.path.join(self.gt_root, "train", clip_name, paths[neighbor])
            if not os.path.exists(img_gt_path):
                img_gt_path = os.path.join(self.gt_root, "test", clip_name, paths[neighbor])

            # PIL loads RGB; reverse to BGR so img2tensor (bgr2rgb=True) restores RGB
            img_gt = np.asarray(Image.open(img_gt_path).convert('RGB'))[:, :, ::-1] / 255.0
            img_gts.append(img_gt)

        # augmentation - flip, rotate
        img_gts = augment(img_gts, self.opt['use_flip'], self.opt['use_rot']) # False, False

        # ------------- generate inpaint frames --------------#
        img_lqs = img_gts
        img_lqs = [Image.fromarray((_ * 255).astype('uint8')) for _ in img_lqs]
        img_lqs = brush_stroke_mask_video(img_lqs)
        img_lqs = [np.array(_) / 255. for _ in img_lqs]

        # ------------ Align -------------#
        if need_align:
            align_lqs, align_gts = [], []
            # index landmarks by the actual frame number so they stay matched
            # when interval > 1 or the clip order was randomly reversed
            for neighbor, img_lq, img_gt in zip(neighbor_list, img_lqs, img_gts):
                landmarks_str = clip_info[neighbor].split(' ')
                landmarks = np.array([float(x) for x in landmarks_str]).reshape(5, 2)
                self.face_aligner.clean_all()

                # align and warp each face
                img_lq, img_gt = self.face_aligner.align_pair_face(img_lq, img_gt, landmarks)
                align_lqs.append(img_lq)
                align_gts.append(img_gt)
            img_lqs, img_gts = align_lqs, align_gts

        img_gts = img2tensor(img_gts)
        img_lqs = img2tensor(img_lqs)
        img_gts = torch.stack(img_gts, dim=0)
        img_lqs = torch.stack(img_lqs, dim=0)

        if self.normalize:
            normalize(img_lqs, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
            normalize(img_gts, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)

        return {'in': img_lqs, 'gt': img_gts, 'key': key}
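The window logic in `__getitem__` picks a centre frame and accepts it only if `[start, end)` lies inside the clip. The sketch below draws the centre directly from the valid range, so a single draw always succeeds, and raises instead of looping when the window cannot fit. `sample_window` is a hypothetical helper and leaves out the random reverse.

```python
import random


def sample_window(clip_length, num_frame, interval, rng=random):
    """Return num_frame indices spaced by interval, all inside [0, clip_length)."""
    if num_frame * interval > clip_length:
        raise ValueError('The window does not fit in the clip.')
    before = num_frame // 2 * interval        # frames before the centre
    after = (num_frame + 1) // 2 * interval   # centre plus frames after it
    # Every centre in this closed range gives start >= 0 and end <= clip_length.
    center = rng.randint(before, clip_length - after)
    return list(range(center - before, center + after, interval))


print(sample_window(10, 5, 2))  # [0, 2, 4, 6, 8]: only one window fits
```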

    def __len__(self):
        return len(self.keys)


================================================
FILE: basicsr/data/paired_image_dataset.py
================================================
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
    GT image pairs.

    There are three modes:
    1. 'lmdb': Use lmdb files.
        If opt['io_backend'] == lmdb.
    2. 'meta_info_file': Use meta information file to generate paths.
        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
    3. 'folder': Scan folders to generate paths.
        The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwargs.
            filename_tmpl (str): Template for each filename. Note that the
                template excludes the file extension. Default: '{}'.
            gt_size (int): Cropped patch size for gt patches.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).
            mean (list[float], optional): Per-channel mean for normalization.
            std (list[float], optional): Per-channel std for normalization.

            scale (int): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(PairedImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        if 'filename_tmpl' in opt:
            self.filename_tmpl = opt['filename_tmpl']
        else:
            self.filename_tmpl = '{}'

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
            self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
                                                          self.opt['meta_info_file'], self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])

        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
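`torchvision.transforms.functional.normalize` computes `(x - mean[c]) / std[c]` for each channel `c`. With the mean and std of 0.5 that `InpaintingDataset` uses, this maps `[0, 1]` to `[-1, 1]`. A plain-Python sketch on a single RGB pixel, including the inverse needed to turn network outputs back into images; both helpers are hypothetical.

```python
def normalize_pixel(values, mean, std):
    """Per-channel (x - mean) / std, the formula torchvision's normalize applies."""
    return [(v - m) / s for v, m, s in zip(values, mean, std)]


def denormalize_pixel(values, mean, std):
    """Inverse mapping: x * std + mean, per channel."""
    return [v * s + m for v, m, s in zip(values, mean, std)]


half = [0.5, 0.5, 0.5]
print(normalize_pixel([0.0, 0.5, 1.0], half, half))  # [-1.0, 0.0, 1.0]
```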


================================================
FILE: basicsr/data/prefetch_dataloader.py
================================================
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader


class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Maximum number of items buffered in the prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def __next__(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Maximum number of items buffered in the prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)


class CPUPrefetcher():
    """CPU prefetcher.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        try:
            return next(self.loader)
        except StopIteration:
            return None

    def reset(self):
        self.loader = iter(self.ori_loader)
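`CPUPrefetcher.next()` returns `None` at the end of an epoch and `reset()` starts a new pass, so a caller can consume it with a `while batch is not None` loop. The sketch mirrors those semantics with a hypothetical `SimplePrefetcher` and `run_epochs`; the loop shape is an assumption about how a training script would use the class.

```python
class SimplePrefetcher:
    """Mirrors CPUPrefetcher: next() returns None when an epoch ends; reset() restarts."""

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        # next() with a default is equivalent to catching StopIteration.
        return next(self.loader, None)

    def reset(self):
        self.loader = iter(self.ori_loader)


def run_epochs(prefetcher, num_epochs):
    seen = []
    for _ in range(num_epochs):
        prefetcher.reset()
        batch = prefetcher.next()
        while batch is not None:  # None marks the end of the epoch
            seen.append(batch)
            batch = prefetcher.next()
    return seen


print(run_epochs(SimplePrefetcher([{'lq': 1}, {'lq': 2}]), num_epochs=2))
# [{'lq': 1}, {'lq': 2}, {'lq': 1}, {'lq': 2}]
```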


class CUDAPrefetcher():
    """CUDA prefetcher.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)
        self.preload()


================================================
FILE: basicsr/data/transforms.py
================================================
import cv2
import random


def mod_crop(img, scale):
    """Mod crop images, used during testing.

    Args:
        img (ndarray): Input image.
        scale (int): Scale factor.

    Returns:
        ndarray: Result image.
    """
    img = img.copy()
    if img.ndim in (2, 3):
        h, w = img.shape[0], img.shape[1]
        h_remainder, w_remainder = h % scale, w % scale
        img = img[:h - h_remainder, :w - w_remainder, ...]
    else:
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    return img
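`mod_crop` trims the bottom and right edges so that height and width are multiples of `scale`; a low-resolution version of exactly `h / scale` by `w / scale` pixels then maps back onto the ground truth without remainder. A sketch on a nested-list image; `mod_crop_shape` and `mod_crop_rows` are hypothetical helpers.

```python
def mod_crop_shape(h, w, scale):
    """Largest (h', w') <= (h, w) with both divisible by scale."""
    return h - h % scale, w - w % scale


def mod_crop_rows(img, scale):
    """Crop a row-major nested-list image, keeping the top-left corner like mod_crop."""
    h, w = mod_crop_shape(len(img), len(img[0]), scale)
    return [row[:w] for row in img[:h]]


print(mod_crop_shape(101, 99, 4))  # (100, 96)
```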


def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
    """Paired random crop.

    It crops lists of lq and gt images with corresponding locations.

    Args:
        img_gts (list[ndarray] | ndarray): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    h_lq, w_lq, _ = img_lqs[0].shape
    h_gt, w_gt, _ = img_gts[0].shape
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs
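`paired_random_crop` samples the patch corner in LQ coordinates and multiplies it by `scale` for the GT patch, so both patches cover the same scene region. The sketch below returns the two boxes as `(top, left, size)` instead of slicing arrays; `paired_crop_boxes` is a hypothetical helper.

```python
import random


def paired_crop_boxes(h_lq, w_lq, gt_patch_size, scale, rng=random):
    """Return (lq_box, gt_box), each (top, left, size), with the GT box aligned to the LQ box."""
    lq_patch_size = gt_patch_size // scale
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError('LQ image is smaller than the patch size.')
    top = rng.randint(0, h_lq - lq_patch_size)
    left = rng.randint(0, w_lq - lq_patch_size)
    # The GT patch starts at the scaled LQ corner and spans gt_patch_size pixels.
    return (top, left, lq_patch_size), (top * scale, left * scale, gt_patch_size)


print(paired_crop_boxes(40, 50, 128, 4))
```

If `gt_patch_size` is not divisible by `scale`, the GT box can extend up to `scale - 1` pixels past the image, and NumPy slicing then silently returns a smaller patch, so a divisible patch size is the safe choice.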


def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: random horizontal flip and/or rotation (0, 90, 180, 270 degrees).

    We use vertical flip and transpose for rotation implementation.
    All the images in the list use the same augmentation. Flips are applied
    in place, so the input arrays are modified.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]): Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Default: False.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If returned
            results only have one element, just return ndarray.

    """
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # horizontal
            cv2.flip(img, 1, img)
        if vflip:  # vertical
            cv2.flip(img, 0, img)
        if rot90:
            img = img.transpose(1, 0, 2)
        return img

    def _augment_flow(flow):
        if hflip:  # horizontal
            cv2.flip(flow, 1, flow)
            flow[:, :, 0] *= -1
        if vflip:  # vertical
            cv2.flip(flow, 0, flow)
            flow[:, :, 1] *= -1
        if rot90:
            flow = flow.transpose(1, 0, 2)
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_augment_flow(flow) for flow in flows]
        if len(flows) == 1:
            flows = flows[0]
        return imgs, flows
    else:
        if return_status:
            return imgs, (hflip, vflip, rot90)
        else:
            return imgs
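Flows need more than the image transform in `augment`: mirroring a frame horizontally reverses horizontal motion, so the x-component changes sign (likewise y for a vertical flip), and transposing the axes swaps the x and y components. A NumPy sketch that uses slicing in place of `cv2.flip` (a swap for portability; it returns copies, while the original flips in place); the function names are hypothetical.

```python
import numpy as np


def hflip_flow(flow):
    """Mirror an (h, w, 2) flow left-right; horizontal motion reverses."""
    out = flow[:, ::-1, :].copy()
    out[:, :, 0] *= -1
    return out


def vflip_flow(flow):
    """Mirror an (h, w, 2) flow top-bottom; vertical motion reverses."""
    out = flow[::-1, :, :].copy()
    out[:, :, 1] *= -1
    return out


def transpose_flow(flow):
    """Swap the image axes (the rot90 step of augment); x and y components swap too."""
    return flow.transpose(1, 0, 2)[:, :, [1, 0]]


flow = np.zeros((2, 3, 2))
flow[0, 0] = [1.0, 2.0]                 # pixel (row 0, col 0) moves 1 right, 2 down
print(hflip_flow(flow)[0, 2].tolist())  # [-1.0, 2.0]
print(transpose_flow(flow).shape)       # (3, 2, 2)
```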


def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate image.

    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. If the center is None,
            initialize it as the center of the image. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.

    Returns:
        ndarray: Rotated image with the same height and width as the input.
    """
    (h, w) = img.shape[:2]

    if center is None:
        center = (w // 2, h // 2)

    matrix = cv2.getRotationMatrix2D(center, angle, scale)
    rotated_img = cv2.warpAffine(img, matrix, (w, h))
    return rotated_img
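`img_rotate` delegates to `cv2.getRotationMatrix2D`, whose documented 2x3 matrix uses `a = scale * cos(angle)` and `b = scale * sin(angle)` and keeps the rotation centre fixed. A pure-Python sketch of that matrix, handy for mapping landmark coordinates the same way the image is warped; `rotation_matrix_2d` and `apply_affine` are hypothetical names.

```python
import math


def rotation_matrix_2d(center, angle_deg, scale=1.0):
    """2x3 affine matrix with the layout documented for cv2.getRotationMatrix2D."""
    cx, cy = center
    a = scale * math.cos(math.radians(angle_deg))
    b = scale * math.sin(math.radians(angle_deg))
    return [[a, b, (1 - a) * cx - b * cy],
            [-b, a, b * cx + (1 - a) * cy]]


def apply_affine(m, x, y):
    """Map point (x, y) through a 2x3 affine matrix."""
    return (m[0][0] * x + m[0][1] * y + m[0][2],
            m[1][0] * x + m[1][1] * y + m[1][2])


x, y = apply_affine(rotation_matrix_2d((0, 0), 90), 1, 0)
print(round(x, 6), round(y, 6))
```

In image coordinates (y pointing down), `(1, 0)` rotated by 90 degrees about the origin lands at about `(0, -1)`, i.e. upward, which is the counter-clockwise rotation the docstring describes.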


================================================
FILE: basicsr/data/vfhq_dataset.py
================================================
import os
import random
from pathlib import Path

from PIL import Image
import cv2
import ffmpeg
import io
import av
import numpy as np
import torch
from torchvision.transforms.functional import normalize
from basicsr.data.degradations import (random_add_gaussian_noise,
                                       random_mixed_kernels)
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, img2tensor, imfrombytes, scandir
from basicsr.utils.registry import DATASET_REGISTRY
from facelib.utils.face_restoration_helper import FaceAligner
from torch.utils import data as data


@DATASET_REGISTRY.register()
class VFHQRealDegradationDatasetNew(data.Dataset):
    """Support for the blind setting adopted in the paper. Compared to GFPGAN,
    the random scale is excluded.

    This dataset is adapted from BasicVSR.

    Images are read directly by cv2 and LR frames are generated online.
    NOTE: The degradation order is blur-noise-downsample-crf-upsample.

    The keys are generated from a meta info txt file.

    Key format: subfolder-name/clip-length/frame-name
    Key examples: "id00020#t0bbIRgKKzM#00381.txt#000.mp4/00000152/00000000"
    GT (gt): Ground-Truth;
    LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
    Args:
        opt (dict): Config for train dataset. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_clip_meta_info (str): Data root path for meta info of each gt clip.
            global_meta_info_file (str): Path for global meta information file.
            io_backend (dict): IO backend type and other kwargs.
            video_length (int): Window size for input frames.
            interval_list (list): Interval list for temporal augmentation.
            random_reverse (bool): Random reverse input frames.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).
    """

    def __init__(self, opt):
        super(VFHQRealDegradationDatasetNew, self).__init__()
        self.opt = opt
        self.gt_root = Path(opt['dataroot_gt'])

        self.num_frame = opt['video_length'] # 5
        self.scale = opt['scale'] # fixed scalar factor, e.g. 4 (random scale is disabled)
        self.need_align = opt.get('need_align', False) # False
        self.normalize = opt.get('normalize', False) # True

        self.keys = []
        with open(opt['global_meta_info_file'], 'r') as fin:
            for line in fin:
                real_clip_path = '/'.join(line.split('/')[:-1])
                clip_length = int(line.split('/')[-1])
                self.keys.extend([f'{real_clip_path}/{clip_length:08d}/{0:08d}'])

        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.is_lmdb = False
        if self.io_backend_opt['type'] == 'lmdb':
            self.is_lmdb = True
            self.io_backend_opt['db_paths'] = [self.gt_root]
            self.io_backend_opt['client_keys'] = ['gt']

        # temporal augmentation configs
        self.interval_list = opt['interval_list'] # [1]
        self.random_reverse = opt['random_reverse']
        interval_str = ','.join(str(x) for x in opt['interval_list']) # '1'
        logger = get_root_logger()
        logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
                    f'random reverse is {self.random_reverse}.')

        # degradations
        # blur
        self.blur_kernel_size = opt['blur_kernel_size'] # 21
        self.kernel_list = opt['kernel_list']           # ['iso', 'aniso']
        self.kernel_prob = opt['kernel_prob']           # [0.5, 0.5]  
        self.blur_x_sigma = opt['blur_x_sigma']         # [0.2, 3]
        self.blur_y_sigma = opt['blur_y_sigma']         # [0.2, 3]
        # noise
        self.noise_range = opt['noise_range']           # [0, 25] 
        # resize
        self.resize_prob = opt['resize_prob']           # [0.25, 0.25, 0.5]
        # crf
        self.crf_range = opt['crf_range']               # [10, 30]
        # codec
        self.vcodec = opt['vcodec']                     # ['libx264']
        self.vcodec_prob = opt['vcodec_prob']           # [1]

        logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
                    f'x_sigma: [{", ".join(map(str, self.blur_x_sigma))}], '
                    f'y_sigma: [{", ".join(map(str, self.blur_y_sigma))}], ')
        logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
        logger.info(f'CRF compression: [{", ".join(map(str, self.crf_range))}]')
        logger.info(f'Codec: [{", ".join(map(str, self.vcodec))}]')

        if self.need_align:
            self.dataroot_meta_info = opt['dataroot_meta_info']
            self.face_aligner = FaceAligner(
                upscale_factor=1,
                face_size=512,
                crop_ratio=(1, 1),
                det_model='retinaface_resnet50',
                save_ext='png',
                use_parse=True)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(
                self.io_backend_opt.pop('type'), **self.io_backend_opt)

        key = self.keys[index]
        real_clip_path = '/'.join(key.split('/')[:-2])
        clip_length = int(key.split('/')[-2])
        frame_idx = int(key.split('/')[-1])
        clip_name = real_clip_path.split('/')[-1]

        if os.path.exists(os.path.join(self.gt_root, "train", clip_name)):
            paths = sorted(list(scandir(os.path.join(self.gt_root, "train", clip_name))))
        elif os.path.exists(os.path.join(self.gt_root, "test", clip_name)):
            paths = sorted(list(scandir(os.path.join(self.gt_root, "test", clip_name))))
        else:
            paths = sorted(list(scandir(os.path.join(self.gt_root, clip_name))))

        # determine the neighboring frames
        interval = random.choice(self.interval_list)

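        # if the temporal window does not fit into the clip, re-sample the interval
        # (this loops forever if no interval in interval_list fits the clip)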
        while (clip_length - self.num_frame * interval) < 0:
            interval = random.choice(self.interval_list)

        # ensure not exceeding the borders
        start_frame_idx = frame_idx - self.num_frame // 2 * interval
        end_frame_idx = frame_idx + (self.num_frame + 1) // 2 * interval

        while (start_frame_idx < 0) or (end_frame_idx > clip_length):
            frame_idx = random.randint(self.num_frame // 2 * interval,
                                       clip_length - self.num_frame // 2 * interval)
            start_frame_idx = frame_idx - self.num_frame // 2 * interval
            end_frame_idx = frame_idx + (self.num_frame + 1) // 2 * interval
        neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))

        # random reverse
        if self.random_reverse and random.random() < 0.5:
            neighbor_list.reverse()

        assert len(neighbor_list) == self.num_frame, (
            f'Wrong length of neighbor list: {len(neighbor_list)}')

        # get the neighboring GT frames
        img_gts = []

        need_align = False
        if self.need_align:
            clip_info_path = os.path.join(self.dataroot_meta_info, f'{clip_name}.txt')
            if os.path.exists(clip_info_path):
                need_align = True
                clip_info = []
                with open(clip_info_path, 'r', encoding='utf-8') as fin:
                    for line in fin:
                        line = line.strip()
                        clip_info.append(line)

        for neighbor in neighbor_list:
            img_gt_path = os.path.join(self.gt_root, clip_name, paths[neighbor])
            if not os.path.exists(img_gt_path):
                img_gt_path = os.path.join(self.gt_root, "train", clip_name, paths[neighbor])
            if not os.path.exists(img_gt_path):
                img_gt_path = os.path.join(self.gt_root, "test", clip_name, paths[neighbor])

            img_gt = np.asarray(Image.open(img_gt_path))[:, :, ::-1] / 255.0
            img_gts.append(img_gt)

        # augmentation - flip, rotate
        img_gts = augment(img_gts, self.opt['use_flip'], self.opt['use_rot']) # False, False

        # ------------- generate LQ frames --------------#
        # add blur
        kernel = random_mixed_kernels(self.kernel_list,
                                      self.kernel_prob,
                                      self.blur_kernel_size,
                                      self.blur_x_sigma,
                                      self.blur_y_sigma)
        img_lqs = [cv2.filter2D(v, -1, kernel) for v in img_gts]

        # downsample
        ori_height, ori_width = img_gts[0].shape[0:2]
        resize_type = random.choices([cv2.INTER_AREA,
                                      cv2.INTER_LINEAR,
                                      cv2.INTER_CUBIC], self.resize_prob)[0]

        # ensure the resized_height and resized_width are even numbers
        # scale = np.random.uniform(self.scale)
        resized_height = int(ori_height // self.scale) // 2 * 2
        resized_width = int(ori_width // self.scale) // 2 * 2
        img_lqs = [cv2.resize(v, (resized_width, resized_height),
                              interpolation=resize_type) for v in img_lqs]

        # add noise
        img_lqs = [random_add_gaussian_noise(v,
                                             self.noise_range,
                                             gray_prob=0.5,
                                             clip=True,
                                             rounds=False) for v in img_lqs]

        # ffmpeg
        crf = np.random.randint(self.crf_range[0], self.crf_range[1]) # upper bound is exclusive
        codec = random.choices(self.vcodec, self.vcodec_prob)[0] # 'libx264'

        buf = io.BytesIO()
        with av.open(buf, 'w', 'mp4') as container:
            stream = container.add_stream(codec, rate=1)
            stream.height = resized_height
            stream.width = resized_width
            stream.pix_fmt = 'yuv420p'
            stream.options = {'crf': str(crf)}

            for img_lq in img_lqs:
                img_lq = np.clip(img_lq * 255, 0, 255).astype(np.uint8)
                frame = av.VideoFrame.from_ndarray(img_lq, format='rgb24')
                frame.pict_type = av.video.frame.PictureType.NONE
                for packet in stream.encode(frame):
                    container.mux(packet)

            # Flush stream
            for packet in stream.encode():
                container.mux(packet)

        img_lqs = []
        with av.open(buf, 'r', 'mp4') as container:
            if container.streams.video:
                for frame in container.decode(video=0):
                    frame = frame.to_rgb().to_ndarray()
                    frame = cv2.resize(frame, (ori_width, ori_height), interpolation=resize_type) # upsample
                    img_lqs.append(frame / 255.)

        assert len(img_lqs) == len(img_gts), 'Wrong length'
        # ------------ Align -------------#
        if need_align:
            align_lqs, align_gts = [], []
            for i, (img_lq, img_gt) in enumerate(zip(img_lqs, img_gts)):
                # index landmarks by the real frame index (correct for interval > 1 and random reverse)
                landmarks_str = clip_info[neighbor_list[i]].split(' ')
                landmarks = np.array([float(x) for x in landmarks_str]).reshape(5, 2)
                self.face_aligner.clean_all()
                # align and warp each face
                img_lq, img_gt = self.face_aligner.align_pair_face(
                    img_lq, img_gt, landmarks)
                align_lqs.append(img_lq)
                align_gts.append(img_gt)
            img_lqs, img_gts = align_lqs, align_gts

        img_gts = img2tensor(img_gts)
        img_lqs = img2tensor(img_lqs)
        img_gts = torch.stack(img_gts, dim=0)
        img_lqs = torch.stack(img_lqs, dim=0)

        if self.normalize:
            normalize(img_lqs, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
            normalize(img_gts, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)

        return {'in': img_lqs, 'gt': img_gts, 'key': key}

    def __len__(self):
        return len(self.keys)

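The key handling and neighbour-frame sampling in `__getitem__` can be sketched in isolation. This is a minimal, hypothetical re-implementation (the helper names `make_key`, `parse_key` and `sample_neighbors` are illustrative, not part of the repository) that assumes the same `clip-path/clip-length/frame-index` key layout and the same border-resampling rule as the dataset code.

```python
import random


def make_key(clip_path, clip_length, frame_idx=0):
    """Hypothetical helper: build 'clip-path/clip-length/frame-index' with 8-digit zero padding."""
    return f'{clip_path}/{clip_length:08d}/{frame_idx:08d}'


def parse_key(key):
    """Hypothetical helper: inverse of make_key -> (clip_path, clip_length, frame_idx)."""
    parts = key.split('/')
    return '/'.join(parts[:-2]), int(parts[-2]), int(parts[-1])


def sample_neighbors(frame_idx, clip_length, num_frame, interval, rng=random):
    """Return num_frame indices spaced by interval, re-centring the window until it fits the clip."""
    if clip_length < num_frame * interval:
        # guard added in this sketch to avoid an endless resampling loop
        raise ValueError('clip is too short for this window/interval')
    half = num_frame // 2 * interval
    start = frame_idx - half
    end = frame_idx + (num_frame + 1) // 2 * interval
    while start < 0 or end > clip_length:
        frame_idx = rng.randint(half, clip_length - half)
        start = frame_idx - half
        end = frame_idx + (num_frame + 1) // 2 * interval
    return list(range(start, end, interval))
```

Per-frame metadata, such as the landmark lines read for alignment, should be looked up through the returned indices so it stays in sync when the list is reversed or `interval > 1`.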

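The LQ generation order (blur, downsample, noise, codec, upsample) can be illustrated with NumPy alone. This is a simplified sketch under stated assumptions: a single isotropic Gaussian replaces `random_mixed_kernels`, block averaging stands in for `cv2.INTER_AREA` (they agree for integer factors), nearest-neighbour upsampling replaces the randomly chosen OpenCV interpolation, and the PyAV/libx264 CRF round-trip is omitted. The function names are hypothetical.

```python
import numpy as np


def gaussian_kernel(size, sigma):
    """Isotropic 2D Gaussian kernel that sums to 1 (stand-in for random_mixed_kernels)."""
    ax = np.arange(size) - (size - 1) / 2.0
    xx, yy = np.meshgrid(ax, ax)
    kernel = np.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma ** 2))
    return kernel / kernel.sum()


def filter2d(img, kernel):
    """'Same'-size correlation of an (H, W, C) image with reflect padding, similar to cv2.filter2D."""
    img = np.asarray(img, dtype=np.float64)
    pad = kernel.shape[0] // 2
    padded = np.pad(img, ((pad, pad), (pad, pad), (0, 0)), mode='reflect')
    h, w = img.shape[:2]
    out = np.zeros_like(img)
    for dy in range(kernel.shape[0]):
        for dx in range(kernel.shape[1]):
            out += kernel[dy, dx] * padded[dy:dy + h, dx:dx + w]
    return out


def degrade(img, scale=4, sigma=1.5, noise_sigma=10.0, rng=None):
    """Blur -> downsample -> noise -> (codec omitted) -> upsample, on a float image in [0, 1]."""
    rng = np.random.default_rng() if rng is None else rng
    h, w, c = img.shape
    # 1) blur
    x = filter2d(img, gaussian_kernel(21, sigma))
    # 2) downsample to an even-sized LR frame by block averaging (crops any remainder)
    rh, rw = h // scale // 2 * 2, w // scale // 2 * 2
    x = x[:rh * scale, :rw * scale].reshape(rh, scale, rw, scale, c).mean(axis=(1, 3))
    # 3) additive Gaussian noise, sigma given on the 0-255 scale, then clip to [0, 1]
    x = np.clip(x + rng.normal(0.0, noise_sigma / 255.0, x.shape), 0.0, 1.0)
    # 4) the dataset encodes/decodes the frames with a video codec at this point (omitted here)
    # 5) nearest-neighbour upsample back to the original (h, w)
    rows = np.arange(h) * rh // h
    cols = np.arange(w) * rw // w
    return x[rows][:, cols]
```

Keeping the LR size even mirrors the dataset code, which needs even dimensions for `yuv420p` chroma subsampling in the codec step.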
================================================
FILE: basicsr/losses/__init__.py
================================================
from copy import deepcopy

from basicsr.utils import get_root_logger
from basicsr.utils.registry import LOSS_REGISTRY
from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
                     gradient_penalty_loss, r1_penalty)

__all__ = [
    'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
    'r1_penalty', 'g_path_regularize'
]


def build_loss(opt):
    """Build loss from options.

    Args:
        opt (dict): Configuration. It must contain:
            type (str): Loss type.
    """
    opt = deepcopy(opt)
    loss_type = opt.pop('type')
    loss = LOSS_REGISTRY.get(loss_type)(**opt)
    logger = get_root_logger()
    logger.info(f'Loss [{loss.__class__.__name__}] is created.')
    return loss


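`build_loss` relies on a name-to-class registry: the `type` entry selects the class and every remaining key is forwarded to its constructor. The sketch below uses a deliberately minimal `Registry` written for illustration (it is a simplified stand-in, not the actual `basicsr.utils.registry` implementation) and a toy `L1Loss` that only stores its options.

```python
from copy import deepcopy


class Registry:
    """Minimal name -> class registry (simplified stand-in for basicsr's Registry)."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        def deco(cls):
            if cls.__name__ in self._obj_map:
                raise KeyError(f'{cls.__name__} is already registered in {self._name}')
            self._obj_map[cls.__name__] = cls
            return cls
        return deco

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(f'No object named {name!r} found in the {self._name} registry')
        return self._obj_map[name]


LOSS_REGISTRY = Registry('loss')


@LOSS_REGISTRY.register()
class L1Loss:
    """Toy loss that only records its constructor options."""

    def __init__(self, loss_weight=1.0, reduction='mean'):
        self.loss_weight = loss_weight
        self.reduction = reduction


def build_loss(opt):
    opt = deepcopy(opt)          # do not mutate the caller's config
    loss_type = opt.pop('type')  # the remaining keys become constructor kwargs
    return LOSS_REGISTRY.get(loss_type)(**opt)
```

Copying the options first is what lets the same config dict be reused, for example when a model is rebuilt from saved options.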
================================================
FILE: basicsr/losses/loss_util.py
================================================
import functools
from torch.nn import functional as F


def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are 'none', 'mean' and 'sum'.

    Returns:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean:1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    else:
        return loss.sum()
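The three reduction modes of `reduce_loss` can be checked without PyTorch. This NumPy analogue is an illustrative sketch (the name `reduce_loss_np` is hypothetical); it compares the reduction strings directly instead of going through the private `F._Reduction.get_enum` helper.

```python
import numpy as np


def reduce_loss_np(loss, reduction='mean'):
    """NumPy analogue of reduce_loss: 'none' keeps the elementwise loss, 'mean'/'sum' collapse it."""
    if reduction == 'none':
        return loss
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    raise ValueError(f'{reduction} is not a valid value for reduction')
```

With a 2x2 elementwise loss of `[[1, 2], [3, 4]]`, `'mean'` gives 2.5, `'sum'` gives 10 and `'none'` returns the array unchanged.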
SYMBOL INDEX (741 symbols across 75 files)

FILE: basicsr/archs/__init__.py
  function build_network (line 19) | def build_network(opt):

FILE: basicsr/archs/arcface_arch.py
  function conv3x3 (line 5) | def conv3x3(inplanes, outplanes, stride=1):
  class BasicBlock (line 16) | class BasicBlock(nn.Module):
    method __init__ (line 27) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 37) | def forward(self, x):
  class IRBlock (line 56) | class IRBlock(nn.Module):
    method __init__ (line 68) | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se...
    method forward (line 82) | def forward(self, x):
  class Bottleneck (line 103) | class Bottleneck(nn.Module):
    method __init__ (line 114) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 126) | def forward(self, x):
  class SEBlock (line 149) | class SEBlock(nn.Module):
    method __init__ (line 157) | def __init__(self, channel, reduction=16):
    method forward (line 164) | def forward(self, x):
  class ResNetArcFace (line 172) | class ResNetArcFace(nn.Module):
    method __init__ (line 183) | def __init__(self, block, layers, use_se=True):
    method _make_layer (line 214) | def _make_layer(self, block, planes, num_blocks, stride=1):
    method forward (line 229) | def forward(self, x):

FILE: basicsr/archs/arch_util.py
  function default_init_weights (line 18) | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
  function make_layer (line 48) | def make_layer(basic_block, num_basic_block, **kwarg):
  class ResidualBlockNoBN (line 64) | class ResidualBlockNoBN(nn.Module):
    method __init__ (line 79) | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
    method forward (line 89) | def forward(self, x):
  class Upsample (line 95) | class Upsample(nn.Sequential):
    method __init__ (line 103) | def __init__(self, scale, num_feat):
  function flow_warp (line 117) | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', ali...
  function resize_flow (line 151) | def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_co...
  function pixel_unshuffle (line 190) | def pixel_unshuffle(x, scale):
  class DCNv2Pack (line 209) | class DCNv2Pack(ModulatedDeformConvPack):
    method forward (line 220) | def forward(self, x, feat):
  function _no_grad_trunc_normal_ (line 239) | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
  function trunc_normal_ (line 277) | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
  function _ntuple (line 304) | def _ntuple(n):

FILE: basicsr/archs/dir_dist_codeformer_multiscale_arch.py
  function calc_mean_std (line 16) | def calc_mean_std(feat, eps=1e-5):
  function adaptive_instance_normalization (line 33) | def adaptive_instance_normalization(content_feat, style_feat):
  class PositionEmbeddingSine (line 50) | class PositionEmbeddingSine(nn.Module):
    method __init__ (line 56) | def __init__(self, num_pos_feats=64, temperature=10000, normalize=Fals...
    method forward (line 67) | def forward(self, x, mask=None):
  function _get_activation_fn (line 93) | def _get_activation_fn(activation):
  class TransformerSALayer (line 104) | class TransformerSALayer(nn.Module):
    method __init__ (line 105) | def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, acti...
    method with_pos_embed (line 120) | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
    method forward (line 123) | def forward(
  class TransformerSALayerTemporal (line 147) | class TransformerSALayerTemporal(nn.Module):
    method __init__ (line 148) | def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, acti...
    method with_pos_embed (line 164) | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
    method forward (line 167) | def forward(self,
  class Fuse_sft_block (line 196) | class Fuse_sft_block(nn.Module):
    method __init__ (line 197) | def __init__(self, in_ch, out_ch):
    method forward (line 211) | def forward(self, enc_feat, dec_feat, w=1):
  class ExpModule (line 219) | class ExpModule(nn.Module):
    method forward (line 220) | def forward(self, x):
  class MultiScaleFuse (line 224) | class MultiScaleFuse(nn.Module):
    method __init__ (line 225) | def __init__(self):
    method forward (line 232) | def forward(self, s64, s32, s16):
  class TemporalCodeFormerDirDistMultiScale (line 244) | class TemporalCodeFormerDirDistMultiScale(VQAutoEncoder):
    method __init__ (line 245) | def __init__(self,
    method _init_weights (line 335) | def _init_weights(self, module):
    method forward (line 344) | def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):

FILE: basicsr/archs/rrdbnet_arch.py
  class ResidualDenseBlock (line 9) | class ResidualDenseBlock(nn.Module):
    method __init__ (line 19) | def __init__(self, num_feat=64, num_grow_ch=32):
    method forward (line 32) | def forward(self, x):
  class RRDB (line 42) | class RRDB(nn.Module):
    method __init__ (line 52) | def __init__(self, num_feat, num_grow_ch=32):
    method forward (line 58) | def forward(self, x):
  class RRDBNet (line 67) | class RRDBNet(nn.Module):
    method __init__ (line 87) | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_bl...
    method forward (line 105) | def forward(self, x):

FILE: basicsr/archs/vgg_arch.py
  function insert_bn (line 36) | def insert_bn(names):
  class VGGFeatureExtractor (line 55) | class VGGFeatureExtractor(nn.Module):
    method __init__ (line 78) | def __init__(self,
    method forward (line 141) | def forward(self, x):

FILE: basicsr/archs/vqgan_arch.py
  function normalize (line 14) | def normalize(in_channels):
  function swish (line 19) | def swish(x):
  class VectorQuantizer (line 24) | class VectorQuantizer(nn.Module):
    method __init__ (line 25) | def __init__(self, codebook_size, emb_dim, beta):
    method forward (line 33) | def forward(self, z):
    method get_codebook_feat (line 72) | def get_codebook_feat(self, indices, shape):
  class GumbelQuantizer (line 87) | class GumbelQuantizer(nn.Module):
    method __init__ (line 88) | def __init__(self, codebook_size, emb_dim, num_hiddens, straight_throu...
    method forward (line 98) | def forward(self, z):
  class Downsample (line 117) | class Downsample(nn.Module):
    method __init__ (line 118) | def __init__(self, in_channels):
    method forward (line 122) | def forward(self, x):
  class Upsample (line 129) | class Upsample(nn.Module):
    method __init__ (line 130) | def __init__(self, in_channels):
    method forward (line 134) | def forward(self, x):
  class ResBlock (line 141) | class ResBlock(nn.Module):
    method __init__ (line 142) | def __init__(self, in_channels, out_channels=None):
    method forward (line 153) | def forward(self, x_in):
  class AttnBlock (line 167) | class AttnBlock(nn.Module):
    method __init__ (line 168) | def __init__(self, in_channels):
    method forward (line 202) | def forward(self, x):
  class Encoder (line 229) | class Encoder(nn.Module):
    method __init__ (line 230) | def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, ...
    method forward (line 269) | def forward(self, x):
  class Generator (line 276) | class Generator(nn.Module):
    method __init__ (line 277) | def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_re...
    method forward (line 319) | def forward(self, x):
  class VQAutoEncoder (line 327) | class VQAutoEncoder(nn.Module):
    method __init__ (line 328) | def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blo...
    method forward (line 385) | def forward(self, x):
  class VQGANDiscriminator (line 395) | class VQGANDiscriminator(nn.Module):
    method __init__ (line 396) | def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
    method forward (line 433) | def forward(self, x):

FILE: basicsr/data/__init__.py
  function build_dataset (line 25) | def build_dataset(dataset_opt):
  function build_dataloader (line 40) | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sample...
  function worker_init_fn (line 96) | def worker_init_fn(worker_id, num_workers, rank, seed):

FILE: basicsr/data/color_dataset.py
  class ColorizationDataset (line 23) | class ColorizationDataset(data.Dataset):
    method __init__ (line 24) | def __init__(self, opt):
    method __getitem__ (line 92) | def __getitem__(self, index):
    method __len__ (line 191) | def __len__(self):

FILE: basicsr/data/data_sampler.py
  class EnlargedSampler (line 6) | class EnlargedSampler(Sampler):
    method __init__ (line 21) | def __init__(self, dataset, num_replicas, rank, ratio=1):
    method __iter__ (line 29) | def __iter__(self):
    method __len__ (line 44) | def __len__(self):
    method set_epoch (line 47) | def set_epoch(self, epoch):

FILE: basicsr/data/data_util.py
  function read_img_seq (line 13) | def read_img_seq(path, require_mod_crop=False, scale=1):
  function generate_frame_indices (line 37) | def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='...
  function paired_paths_from_lmdb (line 89) | def paired_paths_from_lmdb(folders, keys):
  function paired_paths_from_meta_info_file (line 148) | def paired_paths_from_meta_info_file(folders, keys, meta_info_file, file...
  function paired_paths_from_folder (line 192) | def paired_paths_from_folder(folders, keys, filename_tmpl):
  function paths_from_folder (line 228) | def paths_from_folder(folder):
  function paths_from_lmdb (line 243) | def paths_from_lmdb(folder):
  function generate_gaussian_kernel (line 259) | def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
  function duf_downsample (line 277) | def duf_downsample(x, kernel_size=13, scale=4):
  function brush_stroke_mask (line 310) | def brush_stroke_mask(img, color=(255,255,255)):
  function brush_stroke_mask_video (line 367) | def brush_stroke_mask_video(imgs, color=(255,255,255)):
  function random_ff_mask (line 426) | def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70,...

FILE: basicsr/data/degradations.py
  function sigma_matrix2 (line 16) | def sigma_matrix2(sig_x, sig_y, theta):
  function mesh_grid (line 32) | def mesh_grid(kernel_size):
  function pdf2 (line 50) | def pdf2(sigma_matrix, grid):
  function cdf2 (line 66) | def cdf2(d_matrix, grid):
  function bivariate_Gaussian (line 84) | def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isot...
  function bivariate_generalized_Gaussian (line 112) | def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, bet...
  function bivariate_plateau (line 143) | def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None,...
  function random_bivariate_Gaussian (line 176) | def random_bivariate_Gaussian(kernel_size,
  function random_bivariate_generalized_Gaussian (line 220) | def random_bivariate_generalized_Gaussian(kernel_size,
  function random_bivariate_plateau (line 272) | def random_bivariate_plateau(kernel_size,
  function random_mixed_kernels (line 324) | def random_mixed_kernels(kernel_list,
  function circular_lowpass_kernel (line 389) | def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
  function generate_gaussian_noise (line 419) | def generate_gaussian_noise(img, sigma=10, gray_noise=False):
  function add_gaussian_noise (line 438) | def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_nois...
  function generate_gaussian_noise_pt (line 460) | def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
  function add_gaussian_noise_pt (line 492) | def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds...
  function random_generate_gaussian_noise (line 515) | def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
  function random_add_gaussian_noise (line 524) | def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, cl...
  function random_generate_gaussian_noise_pt (line 536) | def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_pro...
  function random_add_gaussian_noise_pt (line 544) | def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0,...
  function generate_poisson_noise (line 559) | def generate_poisson_noise(img, scale=1.0, gray_noise=False):
  function add_poisson_noise (line 586) | def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_nois...
  function generate_poisson_noise_pt (line 609) | def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
  function add_poisson_noise_pt (line 657) | def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_n...
  function random_generate_poisson_noise (line 685) | def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
  function random_add_poisson_noise (line 694) | def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, cli...
  function random_generate_poisson_noise_pt (line 706) | def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_pro...
  function random_add_poisson_noise_pt (line 714) | def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, ...
  function add_jpg_compression (line 731) | def add_jpg_compression(img, quality=90):
  function random_add_jpg_compression (line 750) | def random_add_jpg_compression(img, quality_range=(90, 100)):

FILE: basicsr/data/gaussian_kernels.py
  function sigma_matrix2 (line 8) | def sigma_matrix2(sig_x, sig_y, theta):
  function mesh_grid (line 23) | def mesh_grid(kernel_size):
  function pdf2 (line 40) | def pdf2(sigma_matrix, grid):
  function cdf2 (line 54) | def cdf2(D, grid):
  function bivariate_skew_Gaussian (line 70) | def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=No...
  function mass_center_shift (line 96) | def mass_center_shift(kernel_size, kernel):
  function bivariate_skew_Gaussian_center (line 112) | def bivariate_skew_Gaussian_center(kernel_size,
  function bivariate_anisotropic_Gaussian (line 139) | def bivariate_anisotropic_Gaussian(kernel_size,
  function bivariate_isotropic_Gaussian (line 163) | def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None):
  function bivariate_generalized_Gaussian (line 181) | def bivariate_generalized_Gaussian(kernel_size,
  function bivariate_plateau_type1 (line 213) | def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid...
  function bivariate_plateau_type1_iso (line 237) | def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None):
  function random_bivariate_skew_Gaussian_center (line 259) | def random_bivariate_skew_Gaussian_center(kernel_size,
  function random_bivariate_anisotropic_Gaussian (line 310) | def random_bivariate_anisotropic_Gaussian(kernel_size,
  function random_bivariate_isotropic_Gaussian (line 354) | def random_bivariate_isotropic_Gaussian(kernel_size,
  function random_bivariate_generalized_Gaussian (line 385) | def random_bivariate_generalized_Gaussian(kernel_size,
  function random_bivariate_plateau_type1 (line 435) | def random_bivariate_plateau_type1(kernel_size,
  function random_bivariate_plateau_type1_iso (line 485) | def random_bivariate_plateau_type1_iso(kernel_size,
  function random_mixed_kernels (line 519) | def random_mixed_kernels(kernel_list,
  function show_one_kernel (line 588) | def show_one_kernel():
  function show_plateau_kernel (line 635) | def show_plateau_kernel():

FILE: basicsr/data/inpainting_dataset.py
  class InpaintingDataset (line 23) | class InpaintingDataset(data.Dataset):
    method __init__ (line 24) | def __init__(self, opt):
    method __getitem__ (line 92) | def __getitem__(self, index):
    method __len__ (line 192) | def __len__(self):

FILE: basicsr/data/paired_image_dataset.py
  class PairedImageDataset (line 11) | class PairedImageDataset(data.Dataset):
    method __init__ (line 42) | def __init__(self, opt):
    method __getitem__ (line 67) | def __getitem__(self, index):
    method __len__ (line 100) | def __len__(self):

FILE: basicsr/data/prefetch_dataloader.py
  class PrefetchGenerator (line 7) | class PrefetchGenerator(threading.Thread):
    method __init__ (line 18) | def __init__(self, generator, num_prefetch_queue):
    method run (line 25) | def run(self):
    method __next__ (line 30) | def __next__(self):
    method __iter__ (line 36) | def __iter__(self):
  class PrefetchDataLoader (line 40) | class PrefetchDataLoader(DataLoader):
    method __init__ (line 55) | def __init__(self, num_prefetch_queue, **kwargs):
    method __iter__ (line 59) | def __iter__(self):
  class CPUPrefetcher (line 63) | class CPUPrefetcher():
    method __init__ (line 70) | def __init__(self, loader):
    method next (line 74) | def next(self):
    method reset (line 80) | def reset(self):
  class CUDAPrefetcher (line 84) | class CUDAPrefetcher():
    method __init__ (line 97) | def __init__(self, loader, opt):
    method preload (line 105) | def preload(self):
    method next (line 117) | def next(self):
    method reset (line 123) | def reset(self):

FILE: basicsr/data/transforms.py
  function mod_crop (line 5) | def mod_crop(img, scale):
  function paired_random_crop (line 25) | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
  function augment (line 80) | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=F...
  function img_rotate (line 147) | def img_rotate(img, angle, center=None, scale=1.0):

FILE: basicsr/data/vfhq_dataset.py
  class VFHQRealDegradationDatasetNew (line 23) | class VFHQRealDegradationDatasetNew(data.Dataset):
    method __init__ (line 53) | def __init__(self, opt):
    method __getitem__ (line 121) | def __getitem__(self, index):
    method __len__ (line 277) | def __len__(self):

FILE: basicsr/losses/__init__.py
  function build_loss (line 14) | def build_loss(opt):

FILE: basicsr/losses/loss_util.py
  function reduce_loss (line 5) | def reduce_loss(loss, reduction):
  function weight_reduce_loss (line 25) | def weight_reduce_loss(loss, weight=None, reduction='mean'):
  function weighted_loss (line 57) | def weighted_loss(loss_func):

FILE: basicsr/losses/losses.py
  function l1_loss (line 17) | def l1_loss(pred, target):
  function mse_loss (line 22) | def mse_loss(pred, target):
  function charbonnier_loss (line 27) | def charbonnier_loss(pred, target, eps=1e-12):
  class L1Loss (line 32) | class L1Loss(nn.Module):
    method __init__ (line 41) | def __init__(self, loss_weight=1.0, reduction='mean'):
    method forward (line 49) | def forward(self, pred, target, weight=None, **kwargs):
  class MSELoss (line 61) | class MSELoss(nn.Module):
    method __init__ (line 70) | def __init__(self, loss_weight=1.0, reduction='mean'):
    method forward (line 78) | def forward(self, pred, target, weight=None, **kwargs):
  class CharbonnierLoss (line 90) | class CharbonnierLoss(nn.Module):
    method __init__ (line 105) | def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
    method forward (line 114) | def forward(self, pred, target, weight=None, **kwargs):
  class WeightedTVLoss (line 126) | class WeightedTVLoss(L1Loss):
    method __init__ (line 133) | def __init__(self, loss_weight=1.0):
    method forward (line 136) | def forward(self, pred, weight=None):
  class PerceptualLoss (line 146) | class PerceptualLoss(nn.Module):
    method __init__ (line 169) | def __init__(self,
    method forward (line 199) | def forward(self, x, gt):
    method _gram_mat (line 241) | def _gram_mat(self, x):
  class LPIPSLoss (line 258) | class LPIPSLoss(nn.Module):
    method __init__ (line 259) | def __init__(self,
    method forward (line 275) | def forward(self, pred, target):
  class GANLoss (line 287) | class GANLoss(nn.Module):
    method __init__ (line 299) | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, l...
    method _wgan_loss (line 319) | def _wgan_loss(self, input, target):
    method _wgan_softplus_loss (line 331) | def _wgan_softplus_loss(self, input, target):
    method get_target_label (line 348) | def get_target_label(self, input, target_is_real):
    method forward (line 365) | def forward(self, input, target_is_real, is_disc=False):
  function r1_penalty (line 391) | def r1_penalty(real_pred, real_img):
  function g_path_regularize (line 408) | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
  function gradient_penalty_loss (line 420) | def gradient_penalty_loss(discriminator, real_data, fake_data, weight=No...
  class DirichletKLLoss (line 460) | class DirichletKLLoss(nn.Module):
    method __init__ (line 467) | def __init__(self, loss_weight=1.0, kl_coef=1.1):
    method forward (line 472) | def forward(self, alpha):

FILE: basicsr/metrics/__init__.py
  function calculate_metric (line 9) | def calculate_metric(data, opt):

FILE: basicsr/metrics/metric_util.py
  function reorder_image (line 6) | def reorder_image(img, input_order='HWC'):
  function to_y_channel (line 32) | def to_y_channel(img):

FILE: basicsr/metrics/psnr_ssim.py
  function calculate_psnr (line 9) | def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_ch...
  function _ssim (line 49) | def _ssim(img1, img2):
  function calculate_ssim (line 84) | def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_ch...

FILE: basicsr/models/__init__.py
  function build_model (line 19) | def build_model(opt):

FILE: basicsr/models/base_model.py
  class BaseModel (line 14) | class BaseModel():
    method __init__ (line 17) | def __init__(self, opt):
    method feed_data (line 24) | def feed_data(self, data):
    method optimize_parameters (line 27) | def optimize_parameters(self):
    method get_current_visuals (line 30) | def get_current_visuals(self):
    method save (line 33) | def save(self, epoch, current_iter):
    method validation (line 37) | def validation(self, dataloader, current_iter, tb_logger, save_img=Fal...
    method model_ema (line 51) | def model_ema(self, decay=0.999):
    method get_current_log (line 60) | def get_current_log(self):
    method model_to_device (line 63) | def model_to_device(self, net):
    method get_optimizer (line 79) | def get_optimizer(self, optim_type, params, lr, **kwargs):
    method setup_schedulers (line 86) | def setup_schedulers(self):
    method get_bare_model (line 99) | def get_bare_model(self, net):
    method print_network (line 108) | def print_network(self, net):
    method _set_lr (line 126) | def _set_lr(self, lr_groups_l):
    method _get_init_lr (line 136) | def _get_init_lr(self):
    method update_learning_rate (line 144) | def update_learning_rate(self, current_iter, warmup_iter=-1):
    method get_current_learning_rate (line 167) | def get_current_learning_rate(self):
    method save_network (line 171) | def save_network(self, net, net_label, current_iter, param_key='params'):
    method _print_different_keys_loading (line 202) | def _print_different_keys_loading(self, crt_net, load_net, strict=True):
    method load_network (line 236) | def load_network(self, net, load_path, strict=True, param_key='params'):
    method save_training_state (line 264) | def save_training_state(self, epoch, current_iter):
    method resume_training (line 282) | def resume_training(self, resume_state):
    method reduce_loss_dict (line 297) | def reduce_loss_dict(self, loss_dict):

FILE: basicsr/models/codeformer_dirichlet_video_model.py
  class CodeFormerDirichletVideoModel (line 16) | class CodeFormerDirichletVideoModel(SRModel):
    method feed_data (line 17) | def feed_data(self, data):
    method init_training_settings (line 37) | def init_training_settings(self):
    method calculate_adaptive_weight (line 96) | def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, di...
    method setup_optimizers (line 104) | def setup_optimizers(self):
    method gray_resize_for_identity (line 134) | def gray_resize_for_identity(self, out, size=128):
    method optimize_parameters (line 140) | def optimize_parameters(self, current_iter):
    method test (line 218) | def test(self):
    method dist_validation (line 230) | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    method nondist_validation (line 234) | def nondist_validation(self, dataloader, current_iter, tb_logger, save...
    method _log_validation_metric_values (line 309) | def _log_validation_metric_values(self, current_iter, dataset_name, tb...
    method get_current_visuals (line 319) | def get_current_visuals(self):
    method save (line 325) | def save(self, epoch, current_iter):

FILE: basicsr/models/lr_scheduler.py
  class MultiStepRestartLR (line 6) | class MultiStepRestartLR(_LRScheduler):
    method __init__ (line 19) | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), r...
    method get_lr (line 27) | def get_lr(self):
  function get_position_from_periods (line 36) | def get_position_from_periods(iteration, cumulative_period):
  class CosineAnnealingRestartLR (line 57) | class CosineAnnealingRestartLR(_LRScheduler):
    method __init__ (line 77) | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=...
    method get_lr (line 86) | def get_lr(self):

FILE: basicsr/models/sr_model.py
  class SRModel (line 14) | class SRModel(BaseModel):
    method __init__ (line 17) | def __init__(self, opt):
    method init_training_settings (line 34) | def init_training_settings(self):
    method setup_optimizers (line 72) | def setup_optimizers(self):
    method feed_data (line 86) | def feed_data(self, data):
    method optimize_parameters (line 91) | def optimize_parameters(self, current_iter):
    method test (line 120) | def test(self):
    method dist_validation (line 131) | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    method nondist_validation (line 135) | def nondist_validation(self, dataloader, current_iter, tb_logger, save...
    method _log_validation_metric_values (line 186) | def _log_validation_metric_values(self, current_iter, dataset_name, tb...
    method get_current_visuals (line 196) | def get_current_visuals(self):
    method save (line 204) | def save(self, epoch, current_iter):

FILE: basicsr/models/vqgan_model.py
  class VQGANModel (line 16) | class VQGANModel(SRModel):
    method feed_data (line 17) | def feed_data(self, data):
    method init_training_settings (line 22) | def init_training_settings(self):
    method calculate_adaptive_weight (line 85) | def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, di...
    method adopt_weight (line 93) | def adopt_weight(self, weight, global_step, threshold=0, value=0.):
    method setup_optimizers (line 98) | def setup_optimizers(self):
    method optimize_parameters (line 117) | def optimize_parameters(self, current_iter):
    method test (line 192) | def test(self):
    method dist_validation (line 205) | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    method nondist_validation (line 210) | def nondist_validation(self, dataloader, current_iter, tb_logger, save...
    method _log_validation_metric_values (line 262) | def _log_validation_metric_values(self, current_iter, dataset_name, tb...
    method get_current_visuals (line 273) | def get_current_visuals(self):
    method save (line 279) | def save(self, epoch, current_iter):

FILE: basicsr/ops/dcn/deform_conv.py
  class DeformConvFunction (line 27) | class DeformConvFunction(Function):
    method forward (line 30) | def forward(ctx,
    method backward (line 69) | def backward(ctx, grad_output):
    method _output_size (line 101) | def _output_size(input, weight, padding, dilation, stride):
  class ModulatedDeformConvFunction (line 115) | class ModulatedDeformConvFunction(Function):
    method forward (line 118) | def forward(ctx,
    method backward (line 152) | def backward(ctx, grad_output):
    method _infer_shape (line 172) | def _infer_shape(ctx, input, weight):
  class DeformConv (line 186) | class DeformConv(nn.Module):
    method __init__ (line 188) | def __init__(self,
    method reset_parameters (line 223) | def reset_parameters(self):
    method forward (line 230) | def forward(self, x, offset):
  class DeformConvPack (line 246) | class DeformConvPack(DeformConv):
    method __init__ (line 264) | def __init__(self, *args, **kwargs):
    method init_offset (line 277) | def init_offset(self):
    method forward (line 281) | def forward(self, x):
  class ModulatedDeformConv (line 287) | class ModulatedDeformConv(nn.Module):
    method __init__ (line 289) | def __init__(self,
    method init_weights (line 320) | def init_weights(self):
    method forward (line 329) | def forward(self, x, offset, mask):
  class ModulatedDeformConvPack (line 334) | class ModulatedDeformConvPack(ModulatedDeformConv):
    method __init__ (line 352) | def __init__(self, *args, **kwargs):
    method init_weights (line 365) | def init_weights(self):
    method forward (line 371) | def forward(self, x):

FILE: basicsr/ops/dcn/src/deform_conv_cuda.cpp
  function shape_check (line 62) | void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOu...
  function deform_conv_forward_cuda (line 152) | int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
  function deform_conv_backward_input_cuda (line 262) | int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
  function deform_conv_backward_parameters_cuda (line 376) | int deform_conv_backward_parameters_cuda(
  function modulated_deform_conv_cuda_forward (line 490) | void modulated_deform_conv_cuda_forward(
  function modulated_deform_conv_cuda_backward (line 571) | void modulated_deform_conv_cuda_backward(

FILE: basicsr/ops/dcn/src/deform_conv_ext.cpp
  function deform_conv_forward (line 52) | int deform_conv_forward(at::Tensor input, at::Tensor weight,
  function deform_conv_backward_input (line 70) | int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
  function deform_conv_backward_parameters (line 89) | int deform_conv_backward_parameters(
  function modulated_deform_conv_forward (line 107) | void modulated_deform_conv_forward(
  function modulated_deform_conv_backward (line 127) | void modulated_deform_conv_backward(
  function PYBIND11_MODULE (line 150) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: basicsr/ops/fused_act/fused_act.py
  class FusedLeakyReLUFunctionBackward (line 24) | class FusedLeakyReLUFunctionBackward(Function):
    method forward (line 27) | def forward(ctx, grad_output, out, negative_slope, scale):
    method backward (line 46) | def backward(ctx, gradgrad_input, gradgrad_bias):
  class FusedLeakyReLUFunction (line 54) | class FusedLeakyReLUFunction(Function):
    method forward (line 57) | def forward(ctx, input, bias, negative_slope, scale):
    method backward (line 67) | def backward(ctx, grad_output):
  class FusedLeakyReLU (line 75) | class FusedLeakyReLU(nn.Module):
    method __init__ (line 77) | def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
    method forward (line 84) | def forward(self, input):
  function fused_leaky_relu (line 88) | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):

FILE: basicsr/ops/fused_act/src/fused_bias_act.cpp
  function fused_bias_act (line 14) | torch::Tensor fused_bias_act(const torch::Tensor& input,
  function PYBIND11_MODULE (line 24) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
  function upfirdn2d (line 13) | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor&...
  function PYBIND11_MODULE (line 22) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: basicsr/ops/upfirdn2d/upfirdn2d.py
  class UpFirDn2dBackward (line 24) | class UpFirDn2dBackward(Function):
    method forward (line 27) | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pa...
    method backward (line 67) | def backward(ctx, gradgrad_input):
  class UpFirDn2d (line 91) | class UpFirDn2d(Function):
    method forward (line 94) | def forward(ctx, input, kernel, up, down, pad):
    method backward (line 129) | def backward(ctx, grad_output):
  function upfirdn2d (line 147) | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
  function upfirdn2d_native (line 156) | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, ...

FILE: basicsr/setup.py
  function readme (line 15) | def readme():
  function get_git_hash (line 21) | def get_git_hash():
  function get_hash (line 46) | def get_hash():
  function write_version_py (line 61) | def write_version_py():
  function get_version (line 78) | def get_version():
  function make_cuda_ext (line 84) | def make_cuda_ext(name, module, sources, sources_cuda=None):
  function get_requirements (line 111) | def get_requirements(filename='requirements.txt'):

FILE: basicsr/train.py
  function parse_options (line 26) | def parse_options(root_path, is_train=True):
  function init_loggers (line 61) | def init_loggers(opt):
  function create_train_val_dataloader (line 77) | def create_train_val_dataloader(opt, logger):
  function train_pipeline (line 117) | def train_pipeline(root_path):

FILE: basicsr/utils/dist_util.py
  function init_dist (line 10) | def init_dist(launcher, backend='nccl', **kwargs):
  function _init_dist_pytorch (line 21) | def _init_dist_pytorch(backend, **kwargs):
  function _init_dist_slurm (line 30) | def _init_dist_slurm(backend, port=None):
  function get_dist_info (line 62) | def get_dist_info():
  function master_only (line 76) | def master_only(func):

FILE: basicsr/utils/download_util.py
  function download_file_from_google_drive (line 11) | def download_file_from_google_drive(file_id, save_path):
  function get_confirm_token (line 41) | def get_confirm_token(response):
  function save_response_content (line 48) | def save_response_content(response, destination, file_size=None, chunk_s...
  function load_file_from_url (line 69) | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):

FILE: basicsr/utils/file_client.py
  class BaseStorageBackend (line 5) | class BaseStorageBackend(metaclass=ABCMeta):
    method get (line 14) | def get(self, filepath):
    method get_text (line 18) | def get_text(self, filepath):
  class MemcachedBackend (line 22) | class MemcachedBackend(BaseStorageBackend):
    method __init__ (line 32) | def __init__(self, server_list_cfg, client_cfg, sys_path=None):
    method get (line 47) | def get(self, filepath):
    method get_text (line 54) | def get_text(self, filepath):
  class HardDiskBackend (line 58) | class HardDiskBackend(BaseStorageBackend):
    method get (line 61) | def get(self, filepath):
    method get_text (line 67) | def get_text(self, filepath):
  class LmdbBackend (line 74) | class LmdbBackend(BaseStorageBackend):
    method __init__ (line 94) | def __init__(self, db_paths, client_keys='default', readonly=True, loc...
    method get (line 114) | def get(self, filepath, client_key):
    method get_text (line 128) | def get_text(self, filepath):
  class FileClient (line 132) | class FileClient(object):
    method __init__ (line 151) | def __init__(self, backend='disk', **kwargs):
    method get (line 158) | def get(self, filepath, client_key='default'):
    method get_text (line 166) | def get_text(self, filepath):

FILE: basicsr/utils/img_util.py
  function img2tensor (line 10) | def img2tensor(imgs, bgr2rgb=True, float32=True):
  function tensor2img (line 40) | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
  function tensor2img_fast (line 99) | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
  function tensor2imgs (line 116) | def tensor2imgs(tensor, rgb2bgr=True, min_max=(0, 1)):
  function imfrombytes (line 143) | def imfrombytes(content, flag='color', float32=False):
  function imwrite (line 164) | def imwrite(img, file_path, params=None, auto_mkdir=True):
  function images_to_gif (line 182) | def images_to_gif(image_list, output_path, duration=100, loop=0):
  function crop_border (line 217) | def crop_border(imgs, crop_border):

FILE: basicsr/utils/lmdb_util.py
  function make_lmdb_from_imgs (line 9) | def make_lmdb_from_imgs(data_path,
  function read_img_worker (line 132) | def read_img_worker(path, key, compress_level):
  class LmdbMaker (line 156) | class LmdbMaker():
    method __init__ (line 167) | def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_l...
    method put (line 182) | def put(self, img_byte, key, img_shape):
    method close (line 193) | def close(self):

FILE: basicsr/utils/logger.py
  class MessageLogger (line 10) | class MessageLogger():
    method __init__ (line 22) | def __init__(self, opt, start_iter=1, tb_logger=None):
    method __call__ (line 33) | def __call__(self, log_vars):
  function init_tb_logger (line 78) | def init_tb_logger(log_dir):
  function init_wandb_logger (line 85) | def init_wandb_logger(opt):
  function get_root_logger (line 107) | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_f...
  function get_env_info (line 147) | def get_env_info():

FILE: basicsr/utils/matlab_functions.py
  function cubic (line 6) | def cubic(x):
  function calculate_weights_indices (line 16) | def calculate_weights_indices(in_length, out_length, scale, kernel, kern...
  function imresize (line 86) | def imresize(img, scale, antialiasing=True):
  function rgb2ycbcr (line 169) | def rgb2ycbcr(img, y_only=False):
  function bgr2ycbcr (line 202) | def bgr2ycbcr(img, y_only=False):
  function ycbcr2rgb (line 235) | def ycbcr2rgb(img):
  function ycbcr2bgr (line 264) | def ycbcr2bgr(img):
  function _convert_input_type_range (line 293) | def _convert_input_type_range(img):
  function _convert_output_type_range (line 320) | def _convert_output_type_range(img, dst_type):

FILE: basicsr/utils/misc.py
  function gpu_is_available (line 15) | def gpu_is_available():
  function get_device (line 21) | def get_device(gpu_id=None):
  function set_random_seed (line 35) | def set_random_seed(seed):
  function get_time_str (line 44) | def get_time_str():
  function mkdir_and_rename (line 48) | def mkdir_and_rename(path):
  function make_exp_dirs (line 62) | def make_exp_dirs(opt):
  function scandir (line 74) | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
  function check_resume (line 116) | def check_resume(opt, resume_iter):
  function sizeof_fmt (line 143) | def sizeof_fmt(size, suffix='B'):

FILE: basicsr/utils/options.py
  function ordered_yaml (line 7) | def ordered_yaml():
  function parse (line 32) | def parse(opt_path, root_path, is_train=True):
  function dict2str (line 90) | def dict2str(opt, indent_level=1):

FILE: basicsr/utils/realesrgan_utils.py
  class RealESRGANer (line 14) | class RealESRGANer():
    method __init__ (line 29) | def __init__(self,
    method pre_process (line 71) | def pre_process(self, img):
    method process (line 96) | def process(self):
    method tile_process (line 100) | def tile_process(self):
    method post_process (line 165) | def post_process(self):
    method enhance (line 177) | def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
  class PrefetchReader (line 255) | class PrefetchReader(threading.Thread):
    method __init__ (line 263) | def __init__(self, img_list, num_prefetch_queue):
    method run (line 268) | def run(self):
    method __next__ (line 275) | def __next__(self):
    method __iter__ (line 281) | def __iter__(self):
  class IOConsumer (line 285) | class IOConsumer(threading.Thread):
    method __init__ (line 287) | def __init__(self, opt, que, qid):
    method run (line 293) | def run(self):

FILE: basicsr/utils/registry.py
  class Registry (line 4) | class Registry():
    method __init__ (line 30) | def __init__(self, name):
    method _do_register (line 38) | def _do_register(self, name, obj):
    method register (line 43) | def register(self, obj=None):
    method get (line 62) | def get(self, name):
    method __contains__ (line 68) | def __contains__(self, name):
    method __iter__ (line 71) | def __iter__(self):
    method keys (line 74) | def keys(self):

FILE: basicsr/utils/video_util.py
  function get_video_meta_info (line 17) | def get_video_meta_info(video_path):
  class VideoReader (line 29) | class VideoReader:
    method __init__ (line 30) | def __init__(self, video_path):
    method get_resolution (line 52) | def get_resolution(self):
    method get_fps (line 55) | def get_fps(self):
    method get_audio (line 60) | def get_audio(self):
    method __len__ (line 63) | def __len__(self):
    method get_frame_from_stream (line 66) | def get_frame_from_stream(self):
    method get_frame_from_list (line 73) | def get_frame_from_list(self):
    method get_frame (line 80) | def get_frame(self):
    method close (line 84) | def close(self):
  class VideoWriter (line 89) | class VideoWriter:
    method __init__ (line 90) | def __init__(self, video_save_path, height, width, fps, audio):
    method write_frame (line 113) | def write_frame(self, frame):
    method close (line 123) | def close(self):

FILE: facelib/detection/__init__.py
  function init_detection_model (line 14) | def init_detection_model(model_name, half=False, device='cuda'):
  function init_retinaface_model (line 25) | def init_retinaface_model(model_name, half=False, device='cuda'):
  function init_yolov5face_model (line 49) | def init_yolov5face_model(model_name, device='cuda'):

FILE: facelib/detection/align_trans.py
  class FaceWarpException (line 13) | class FaceWarpException(Exception):
    method __str__ (line 15) | def __str__(self):
  function get_reference_facial_points (line 19) | def get_reference_facial_points(output_size=None, inner_padding_factor=0...
  function get_affine_transform_matrix (line 112) | def get_affine_transform_matrix(src_pts, dst_pts):
  function warp_and_crop_face (line 145) | def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_siz...

FILE: facelib/detection/matlab_cp2tform.py
  class MatlabCp2tormException (line 7) | class MatlabCp2tormException(Exception):
    method __str__ (line 9) | def __str__(self):
  function tformfwd (line 13) | def tformfwd(trans, uv):
  function tforminv (line 37) | def tforminv(trans, uv):
  function findNonreflectiveSimilarity (line 60) | def findNonreflectiveSimilarity(uv, xy, options=None):
  function findSimilarity (line 94) | def findSimilarity(uv, xy, options=None):
  function get_similarity_transform (line 130) | def get_similarity_transform(src_pts, dst_pts, reflective=True):
  function cvt_tform_mat_for_cv2 (line 170) | def cvt_tform_mat_for_cv2(trans):
  function get_similarity_transform_for_cv2 (line 198) | def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):

FILE: facelib/detection/retinaface/retinaface.py
  function generate_config (line 19) | def generate_config(network_name):
  class RetinaFace (line 75) | class RetinaFace(nn.Module):
    method __init__ (line 77) | def __init__(self, network_name='resnet50', half=False, phase='test'):
    method forward (line 122) | def forward(self, inputs):
    method __detect_faces (line 147) | def __detect_faces(self, inputs):
    method transform (line 167) | def transform(self, image, use_origin_size):
    method detect_faces (line 194) | def detect_faces(
    method __align_multi (line 241) | def __align_multi(self, image, boxes, landmarks, limit=None):
    method align_multi (line 259) | def align_multi(self, img, conf_threshold=0.8, limit=None):
    method batched_transform (line 267) | def batched_transform(self, frames, use_origin_size):
    method batched_detect_faces (line 310) | def batched_detect_faces(self, frames, conf_threshold=0.8, nms_thresho...

FILE: facelib/detection/retinaface/retinaface_net.py
  function conv_bn (line 6) | def conv_bn(inp, oup, stride=1, leaky=0):
  function conv_bn_no_relu (line 12) | def conv_bn_no_relu(inp, oup, stride):
  function conv_bn1X1 (line 19) | def conv_bn1X1(inp, oup, stride, leaky=0):
  function conv_dw (line 25) | def conv_dw(inp, oup, stride, leaky=0.1):
  class SSH (line 36) | class SSH(nn.Module):
    method __init__ (line 38) | def __init__(self, in_channel, out_channel):
    method forward (line 52) | def forward(self, input):
  class FPN (line 66) | class FPN(nn.Module):
    method __init__ (line 68) | def __init__(self, in_channels_list, out_channels):
    method forward (line 80) | def forward(self, input):
  class MobileNetV1 (line 100) | class MobileNetV1(nn.Module):
    method __init__ (line 102) | def __init__(self):
    method forward (line 127) | def forward(self, x):
  class ClassHead (line 138) | class ClassHead(nn.Module):
    method __init__ (line 140) | def __init__(self, inchannels=512, num_anchors=3):
    method forward (line 145) | def forward(self, x):
  class BboxHead (line 152) | class BboxHead(nn.Module):
    method __init__ (line 154) | def __init__(self, inchannels=512, num_anchors=3):
    method forward (line 158) | def forward(self, x):
  class LandmarkHead (line 165) | class LandmarkHead(nn.Module):
    method __init__ (line 167) | def __init__(self, inchannels=512, num_anchors=3):
    method forward (line 171) | def forward(self, x):
  function make_class_head (line 178) | def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
  function make_bbox_head (line 185) | def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
  function make_landmark_head (line 192) | def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):

FILE: facelib/detection/retinaface/retinaface_utils.py
  class PriorBox (line 8) | class PriorBox(object):
    method __init__ (line 10) | def __init__(self, cfg, image_size=None, phase='train'):
    method forward (line 19) | def forward(self):
  function py_cpu_nms (line 39) | def py_cpu_nms(dets, thresh):
  function point_form (line 50) | def point_form(boxes):
  function center_size (line 65) | def center_size(boxes):
  function intersect (line 79) | def intersect(box_a, box_b):
  function jaccard (line 98) | def jaccard(box_a, box_b):
  function matrix_iou (line 117) | def matrix_iou(a, b):
  function matrix_iof (line 130) | def matrix_iof(a, b):
  function match (line 142) | def match(threshold, truths, priors, variances, labels, landms, loc_t, c...
  function encode (line 200) | def encode(matched, priors, variances):
  function encode_landm (line 224) | def encode_landm(matched, priors, variances):
  function decode (line 254) | def decode(loc, priors, variances):
  function decode_landm (line 274) | def decode_landm(pre, priors, variances):
  function batched_decode (line 297) | def batched_decode(b_loc, priors, variances):
  function batched_decode_landm (line 320) | def batched_decode_landm(pre, priors, variances):
  function log_sum_exp (line 343) | def log_sum_exp(x):
  function nms (line 357) | def nms(boxes, scores, overlap=0.5, top_k=200):

FILE: facelib/detection/yolov5face/face_detector.py
  function isListempty (line 22) | def isListempty(inList):
  class YoloDetector (line 27) | class YoloDetector:
    method __init__ (line 28) | def __init__(
    method _preprocess (line 48) | def _preprocess(self, imgs):
    method _postprocess (line 69) | def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres):
    method detect_faces (line 104) | def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5):
    method __call__ (line 140) | def __call__(self, *args):

FILE: facelib/detection/yolov5face/models/common.py
  function autopad (line 18) | def autopad(k, p=None):  # kernel, padding
  function channel_shuffle (line 25) | def channel_shuffle(x, groups):
  function DWConv (line 37) | def DWConv(c1, c2, k=1, s=1, act=True):
  class Conv (line 42) | class Conv(nn.Module):
    method __init__ (line 44) | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in,...
    method forward (line 50) | def forward(self, x):
    method fuseforward (line 53) | def fuseforward(self, x):
  class StemBlock (line 57) | class StemBlock(nn.Module):
    method __init__ (line 58) | def __init__(self, c1, c2, k=3, s=2, p=None, g=1, act=True):
    method forward (line 66) | def forward(self, x):
  class Bottleneck (line 74) | class Bottleneck(nn.Module):
    method __init__ (line 76) | def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_ou...
    method forward (line 83) | def forward(self, x):
  class BottleneckCSP (line 87) | class BottleneckCSP(nn.Module):
    method __init__ (line 89) | def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ...
    method forward (line 100) | def forward(self, x):
  class C3 (line 106) | class C3(nn.Module):
    method __init__ (line 108) | def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ...
    method forward (line 116) | def forward(self, x):
  class ShuffleV2Block (line 120) | class ShuffleV2Block(nn.Module):
    method __init__ (line 121) | def __init__(self, inp, oup, stride):
    method depthwise_conv (line 160) | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
    method forward (line 163) | def forward(self, x):
  class SPP (line 173) | class SPP(nn.Module):
    method __init__ (line 175) | def __init__(self, c1, c2, k=(5, 9, 13)):
    method forward (line 182) | def forward(self, x):
  class Focus (line 187) | class Focus(nn.Module):
    method __init__ (line 189) | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in,...
    method forward (line 193) | def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
  class Concat (line 197) | class Concat(nn.Module):
    method __init__ (line 199) | def __init__(self, dimension=1):
    method forward (line 203) | def forward(self, x):
  class NMS (line 207) | class NMS(nn.Module):
    method forward (line 213) | def forward(self, x):
  class AutoShape (line 217) | class AutoShape(nn.Module):
    method __init__ (line 224) | def __init__(self, model):
    method autoshape (line 228) | def autoshape(self):
    method forward (line 232) | def forward(self, imgs, size=640, augment=False, profile=False):
  class Detections (line 275) | class Detections:
    method __init__ (line 277) | def __init__(self, imgs, pred, names=None):
    method __len__ (line 290) | def __len__(self):
    method tolist (line 293) | def tolist(self):

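The index lists only signatures for `autopad` and `Focus`. Assuming they follow upstream YOLOv5, which this file appears to be adapted from, `autopad` returns "same" padding of `k // 2` per kernel dimension when `p` is None. `Focus` concatenates four pixel-interleaved slices along the channel axis before its convolution, which matches the `x(b,c,w,h) -> y(b,4c,w/2,h/2)` comment above. The numpy sketch below covers just those two steps. `focus_space_to_depth` is a hypothetical name, and the convolution is omitted.

```python
import numpy as np


def autopad(k, p=None):
    """Return 'same' padding for kernel size k (int or sequence) when p is None."""
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
    return p


def focus_space_to_depth(x):
    """Slice-and-concat step of a YOLOv5-style Focus layer (its conv is omitted).

    x has shape (b, c, h, w); the result has shape (b, 4c, h/2, w/2).
    """
    return np.concatenate(
        [x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]],
        axis=1,
    )


x = np.arange(2 * 3 * 4 * 6).reshape(2, 3, 4, 6)
y = focus_space_to_depth(x)
print(autopad(3), autopad((1, 5)), y.shape)  # 1 [0, 2] (2, 12, 2, 3)
```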
FILE: facelib/detection/yolov5face/models/experimental.py
  class CrossConv (line 10) | class CrossConv(nn.Module):
    method __init__ (line 12) | def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
    method forward (line 20) | def forward(self, x):
  class MixConv2d (line 24) | class MixConv2d(nn.Module):
    method __init__ (line 26) | def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
    method forward (line 44) | def forward(self, x):

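`MixConv2d` takes a tuple of kernel sizes `k`. In upstream YOLOv5, `equal_ch=True` divides the `c2` output channels as evenly as possible among the kernel sizes. Each group then gets its own convolution, and the outputs are concatenated. The sketch below shows that channel split, assuming this copy keeps the upstream rule. `mixconv_equal_split` is a hypothetical helper.

```python
import numpy as np


def mixconv_equal_split(c2, k=(1, 3)):
    """Output channels assigned to each kernel size when equal_ch=True.

    Spreads c2 channel indices over len(k) groups with linspace, floors them,
    and counts how many indices land in each group.
    """
    groups = len(k)
    idx = np.floor(np.linspace(0, groups - 1e-6, c2))
    return [int((idx == g).sum()) for g in range(groups)]


print(mixconv_equal_split(8))                # [4, 4]
print(mixconv_equal_split(7, k=(1, 3, 5)))   # [3, 2, 2]
```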
FILE: facelib/detection/yolov5face/models/yolo.py
  class Detect (line 29) | class Detect(nn.Module):
    method __init__ (line 33) | def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
    method forward (line 46) | def forward(self, x):
    method _make_grid (line 89) | def _make_grid(nx=20, ny=20):
  class Model (line 95) | class Model(nn.Module):
    method __init__ (line 96) | def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None):  # model, input...
    method forward (line 120) | def forward(self, x):
    method forward_once (line 123) | def forward_once(self, x):
    method _initialize_biases (line 134) | def _initialize_biases(self, cf=None):  # initialize biases into Detec...
    method _print_biases (line 143) | def _print_biases(self):
    method fuse (line 149) | def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
    method nms (line 160) | def nms(self, mode=True):  # add or remove NMS module
    method autoshape (line 174) | def autoshape(self):  # add autoShape module
  function parse_model (line 181) | def parse_model(d, ch):  # model_dict, input_channels(3)

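`Detect._make_grid(nx, ny)` and `Detect.forward` are where raw head outputs become boxes. In upstream YOLOv5, the grid holds each cell's (x, y) offset. Boxes are decoded as xy = (2σ(t) − 0.5 + grid) · stride and wh = (2σ(t))² · anchor. The numpy sketch below implements those formulas, assuming this face variant keeps them for the box part. Its extra landmark outputs are not shown, and `decode_boxes` is a hypothetical name.

```python
import numpy as np


def make_grid(nx=20, ny=20):
    """Cell-offset grid shaped (1, 1, ny, nx, 2); the last axis holds (x, y)."""
    yv, xv = np.meshgrid(np.arange(ny), np.arange(nx), indexing="ij")
    return np.stack((xv, yv), axis=2).reshape(1, 1, ny, nx, 2).astype(np.float32)


def decode_boxes(raw, grid, anchor_wh, stride):
    """YOLOv5-style box decoding for one detection level.

    raw: (..., 4) logits for (x, y, w, h); anchor_wh: anchor (w, h) in pixels.
    Returns centre-format boxes (cx, cy, w, h) in input-image pixels.
    """
    y = 1.0 / (1.0 + np.exp(-raw))                     # sigmoid
    xy = (y[..., 0:2] * 2.0 - 0.5 + grid) * stride     # centre offset within the cell
    wh = (y[..., 2:4] * 2.0) ** 2 * anchor_wh          # size relative to the anchor
    return np.concatenate([xy, wh], axis=-1)


g = make_grid(3, 2)
out = decode_boxes(np.zeros((1, 1, 2, 3, 4)), g, np.array([10.0, 13.0]), 8)
print(out[0, 0, 1, 2])  # [20. 12. 10. 13.]
```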
FILE: facelib/detection/yolov5face/utils/autoanchor.py
  function check_anchor_order (line 4) | def check_anchor_order(m):

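`check_anchor_order(m)` guards against anchors listed in the opposite order to the detection strides. Upstream YOLOv5 compares the sign of the change in anchor area with the sign of the change in stride across levels. If they disagree, it flips the anchors in place. The numpy sketch below returns a new array instead of mutating a module. The `anchors` and `strides` arguments are assumptions.

```python
import numpy as np


def check_anchor_order(anchors, strides):
    """Return anchors ordered so that anchor area grows with stride.

    anchors: (nl, na, 2) array of (w, h) per detection level;
    strides: (nl,) stride of each level.
    """
    area = anchors.prod(-1).reshape(-1)
    da = area[-1] - area[0]          # change in anchor area across levels
    ds = strides[-1] - strides[0]    # change in stride across levels
    if np.sign(da) != np.sign(ds):   # orders disagree: flip the level axis
        anchors = anchors[::-1].copy()
    return anchors


anchors = np.array([[[116, 90]], [[30, 61]], [[10, 13]]], dtype=float)
fixed = check_anchor_order(anchors, np.array([8, 16, 32]))
print(fixed[:, 0].tolist())  # [[10.0, 13.0], [30.0, 61.0], [116.0, 90.0]]
```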
FILE: facelib/detection/yolov5face/utils/datasets.py
  function letterbox (line 5) | def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=Tru...

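`letterbox` resizes an image while keeping its aspect ratio and pads the remainder with `color`. The geometry follows upstream YOLOv5: a single scale ratio `r`, an unpadded size, and padding split between both sides. With `auto=True`, the padding is reduced to its remainder modulo a fixed multiple. That multiple is a parameter here (`stride`, an assumption) because YOLOv5 variants use different values. The pure-Python sketch below covers the arithmetic only and omits the `cv2` resize and border calls.

```python
def letterbox_geometry(shape, new_shape=(640, 640), auto=True, scaleup=True, stride=32):
    """Scale ratio and per-side padding for a YOLOv5-style letterbox.

    shape and new_shape are (height, width).
    Returns (ratio, (new_w, new_h), (pad_w_per_side, pad_h_per_side)).
    """
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:
        r = min(r, 1.0)  # only shrink, never enlarge
    new_unpad = (int(round(shape[1] * r)), int(round(shape[0] * r)))
    dw = new_shape[1] - new_unpad[0]
    dh = new_shape[0] - new_unpad[1]
    if auto:
        dw, dh = dw % stride, dh % stride  # pad only up to a multiple of `stride`
    return r, new_unpad, (dw / 2, dh / 2)


print(letterbox_geometry((480, 640), auto=False))  # (1.0, (640, 480), (0.0, 80.0))
```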
FILE: facelib/detection/yolov5face/utils/general.py
  function check_img_size (line 9) | def check_img_size(img_size, s=32):
  function make_divisible (line 17) | def make_divisible(x, divisor):
  function xyxy2xywh (line 22) | def xyxy2xywh(x):
  function xywh2xyxy (line 32) | def xywh2xyxy(x):
  function scale_coords (line 42) | def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  function clip_coords (line 58) | def clip_coords(boxes, img_shape):
  function box_iou (line 66) | def box_iou(box1, box2):
  function non_max_suppression_face (line 89) | def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45...
  function non_max_suppression (line 168) | def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, cla...
  function scale_coords_landmarks (line 249) | def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):

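The box helpers in `general.py` work on (x1, y1, x2, y2) or (cx, cy, w, h) rows. The numpy sketch below implements `xywh2xyxy`, pairwise `box_iou`, and the suppression step behind `non_max_suppression`. Upstream YOLOv5 delegates the suppression step to `torchvision.ops.nms`. The greedy loop here (`greedy_nms`, hypothetical) is an illustrative stand-in.

```python
import numpy as np


def xywh2xyxy(x):
    """(cx, cy, w, h) -> (x1, y1, x2, y2) for an (n, 4) array."""
    y = np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2
    y[:, 1] = x[:, 1] - x[:, 3] / 2
    y[:, 2] = x[:, 0] + x[:, 2] / 2
    y[:, 3] = x[:, 1] + x[:, 3] / 2
    return y


def box_iou(box1, box2):
    """Pairwise IoU between (n, 4) and (m, 4) xyxy boxes, returned as (n, m)."""
    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    lt = np.maximum(box1[:, None, :2], box2[None, :, :2])  # intersection top-left
    rb = np.minimum(box1[:, None, 2:], box2[None, :, 2:])  # intersection bottom-right
    inter = np.clip(rb - lt, 0, None).prod(-1)
    return inter / (area1[:, None] + area2[None, :] - inter)


def greedy_nms(boxes, scores, iou_thres=0.45):
    """Keep the best box, drop boxes overlapping it above iou_thres, repeat."""
    order = np.argsort(-scores)
    keep = []
    while order.size:
        i = order[0]
        keep.append(int(i))
        ious = box_iou(boxes[i:i + 1], boxes[order[1:]])[0]
        order = order[1:][ious <= iou_thres]
    return keep


boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=float)
print(greedy_nms(boxes, np.array([0.9, 0.8, 0.7])))  # [0, 2]
```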
FILE: facelib/detection/yolov5face/utils/torch_utils.py
  function fuse_conv_and_bn (line 5) | def fuse_conv_and_bn(conv, bn):
  function copy_attr (line 34) | def copy_attr(a, b, include=(), exclude=()):

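`fuse_conv_and_bn` folds an inference-mode BatchNorm into the preceding convolution. This works because BN is an affine map per output channel. With s = γ / √(σ² + ε), the fused weight is W · s and the fused bias is β + (b − μ) · s. The numpy sketch below is not the repository code; it checks the identity on a 1x1 convolution.

```python
import numpy as np


def fuse_conv_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold BatchNorm (running statistics) into a conv weight and bias.

    w: (c_out, c_in, kh, kw) conv weight; b: (c_out,) conv bias (zeros if the conv has none).
    """
    scale = gamma / np.sqrt(var + eps)          # per-output-channel factor s
    w_fused = w * scale[:, None, None, None]
    b_fused = beta + (b - mean) * scale
    return w_fused, b_fused


# Check on a 1x1 convolution, where conv reduces to a matrix product per pixel.
rng = np.random.default_rng(0)
c_in, c_out = 3, 4
w = rng.normal(size=(c_out, c_in, 1, 1))
b = rng.normal(size=c_out)
gamma, beta, mean = rng.normal(size=(3, c_out))
var = rng.uniform(0.5, 2.0, size=c_out)
x = rng.normal(size=(5, c_in))  # 5 pixels, c_in channels each

conv = x @ w[:, :, 0, 0].T + b
bn = gamma * (conv - mean) / np.sqrt(var + 1e-5) + beta
w_f, b_f = fuse_conv_bn(w, b, gamma, beta, mean, var)
fused = x @ w_f[:, :, 0, 0].T + b_f
print(np.allclose(bn, fused))  # True
```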
FILE: facelib/parsing/__init__.py
  function init_parsing_model (line 8) | def init_parsing_model(model_name='bisenet', half=False, device='cuda'):

FILE: facelib/parsing/bisenet.py
  class ConvBNReLU (line 8) | class ConvBNReLU(nn.Module):
    method __init__ (line 10) | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
    method forward (line 15) | def forward(self, x):
  class BiSeNetOutput (line 21) | class BiSeNetOutput(nn.Module):
    method __init__ (line 23) | def __init__(self, in_chan, mid_chan, num_class):
    method forward (line 28) | def forward(self, x):
  class AttentionRefinementModule (line 34) | class AttentionRefinementModule(nn.Module):
    method __init__ (line 36) | def __init__(self, in_chan, out_chan):
    method forward (line 43) | def forward(self, x):
  class ContextPath (line 53) | class ContextPath(nn.Module):
    method __init__ (line 55) | def __init__(self):
    method forward (line 64) | def forward(self, x):
  class FeatureFusionModule (line 87) | class FeatureFusionModule(nn.Module):
    method __init__ (line 89) | def __init__(self, in_chan, out_chan):
    method forward (line 97) | def forward(self, fsp, fcp):
  class BiSeNet (line 110) | class BiSeNet(nn.Module):
    method __init__ (line 112) | def __init__(self, num_class):
    method forward (line 120) | def forward(self, x, return_feat=False):

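The BiSeNet `AttentionRefinementModule` reweights channels using global context. Assume it follows the widely used face-parsing BiSeNet. Its forward pass would then be a conv, a global average pool, a 1x1 conv, BatchNorm, and a sigmoid, followed by channel-wise multiplication with the features. The numpy sketch below covers only the pooling-and-gating core. `attention_refine` and the plain weight matrix are hypothetical, and the conv and BatchNorm are omitted.

```python
import numpy as np


def attention_refine(feat, w_atten):
    """Channel gating in the style of an attention refinement module.

    feat: (c, h, w) feature map; w_atten: (c, c) weights standing in for a 1x1 conv.
    """
    pooled = feat.mean(axis=(1, 2))                    # global average pool -> (c,)
    atten = 1.0 / (1.0 + np.exp(-(w_atten @ pooled)))  # 1x1 conv + sigmoid
    return feat * atten[:, None, None]                 # scale each channel


out = attention_refine(np.ones((2, 3, 3)), np.zeros((2, 2)))
print(out[0, 0, 0])  # 0.5
```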
FILE: facelib/parsing/parsenet.py
  class NormLayer (line 8) | class NormLayer(nn.Module):
    method __init__ (line 16) | def __init__(self, channels, normalize_shape=None, norm_type='bn'):
    method forward (line 35) | def forward(self, x, ref=None):
  class ReluLayer (line 42) | class ReluLayer(nn.Module):
    method __init__ (line 54) | def __init__(self, channels, relu_type='relu'):
    method forward (line 70) | def forward(self, x):
  class ConvLayer (line 74) | class ConvLayer(nn.Module):
    method __init__ (line 76) | def __init__(self,
    method forward (line 103) | def forward(self, x):
  class ResidualBlock (line 113) | class ResidualBlock(nn.Module):
    method __init__ (line 118) | def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', sca...
    method forward (line 132) | def forward(self, x):
  class ParseNet (line 140) | class ParseNet(nn.Module):
    method __init__ (line 142) | def __init__(self,
    method forward (line 188) | def forward(self, x):

FILE: facelib/parsing/resnet.py
  function conv3x3 (line 5) | def conv3x3(in_planes, out_planes, stride=1):
  class BasicBlock (line 10) | class BasicBlock(nn.Module):
    method __init__ (line 12) | def __init__(self, in_chan, out_chan, stride=1):
    method forward (line 26) | def forward(self, x):
  function create_layer_basic (line 41) | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
  class ResNet18 (line 48) | class ResNet18(nn.Module):
    method __init__ (line 50) | def __init__(self):
    method forward (line 60) | def forward(self, x):

FILE: facelib/utils/face_restoration_helper.py
  function get_largest_face (line 25) | def get_largest_face(det_faces, h, w):
  function get_center_face (line 47) | def get_center_face(det_faces, h=0, w=0, center=None):
  class FaceRestoreHelper (line 63) | class FaceRestoreHelper(object):
    method __init__ (line 66) | def __init__(
    method set_upscale_factor (line 159) | def set_upscale_factor(self, upscale_factor):
    method read_image (line 162) | def read_image(self, img):
    method init_dlib (line 185) | def init_dlib(self, detection_path, landmark5_path):
    method get_face_landmarks_5_dlib (line 199) | def get_face_landmarks_5_dlib(self,
    method get_face_landmarks_5 (line 231) | def get_face_landmarks_5(self,
    method align_warp_face (line 362) | def align_warp_face(self, save_cropped_path=None, border_mode='constan...
    method get_inverse_affine (line 402) | def get_inverse_affine(self, save_inverse_affine_path=None):
    method add_restored_face (line 414) | def add_restored_face(self, restored_face, input_face=None):
    method paste_faces_to_input_image (line 423) | def paste_faces_to_input_image(
    method clean_all (line 588) | def clean_all(self):
  class FaceAligner (line 598) | class FaceAligner(object):
    method __init__ (line 599) | def __init__(self,
    method set_image (line 664) | def set_image(self, img):
    method align_pair_face (line 667) | def align_pair_face(self, img_lq, img_gt, landmarks):
    method align_single_face (line 677) | def align_single_face(self, img, landmarks, border_mode='constant'):
    method align_warp_face (line 698) | def align_warp_face(self, img_lq, img_gt, landmarks, border_mode='cons...
    method clean_all (line 726) | def clean_all(self):

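`get_largest_face` and `get_center_face` choose one detection when several faces are found. The sketch below assumes each detection starts with (x1, y1, x2, y2). It ranks detections either by area after clamping to the image, or by the distance from each box centre to the image centre (or to an explicit `center`). The names `pick_largest_face` and `pick_center_face` are hypothetical, and the repository functions may differ in detail.

```python
import numpy as np


def pick_largest_face(det_faces, h, w):
    """Return (box, index) of the detection with the largest in-image area."""
    areas = []
    for face in det_faces:
        left, right = np.clip([face[0], face[2]], 0, w)   # clamp to image width
        top, bottom = np.clip([face[1], face[3]], 0, h)   # clamp to image height
        areas.append((right - left) * (bottom - top))
    idx = int(np.argmax(areas))
    return det_faces[idx], idx


def pick_center_face(det_faces, h=0, w=0, center=None):
    """Return (box, index) of the detection whose centre is nearest `center`
    (defaults to the image centre)."""
    center = np.array(center if center is not None else [w / 2, h / 2], dtype=float)
    dists = [
        np.linalg.norm(np.array([(f[0] + f[2]) / 2, (f[1] + f[3]) / 2]) - center)
        for f in det_faces
    ]
    idx = int(np.argmin(dists))
    return det_faces[idx], idx


faces = [[0, 0, 10, 10], [50, 50, 120, 120], [40, 40, 60, 60]]
print(pick_largest_face(faces, 100, 100)[1], pick_center_face(faces, 100, 100)[1])  # 1 2
```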
FILE: facelib/utils/face_utils.py
  function compute_increased_bbox (line 6) | def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
  function get_valid_bboxes (line 23) | def get_valid_bboxes(bboxes, h, w):
  function align_crop_face_landmarks (line 31) | def align_crop_face_landmarks(img,
  function paste_face_back (line 190) | def paste_face_back(img, face, inverse_affine):

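`compute_increased_bbox` and `get_valid_bboxes` appear to enlarge a face box for context and keep it inside the frame. The simplified sketch below follows that reading. `grow_bbox` (hypothetical) expands each side by `increase_area` times the box width or height and ignores the `preserve_aspect` option. `clamp_bbox` (hypothetical) clips the result to an h x w image.

```python
def grow_bbox(bbox, increase_area):
    """Grow an (x1, y1, x2, y2) box by increase_area x width/height on every side."""
    left, top, right, bot = bbox
    width, height = right - left, bot - top
    return (
        int(left - increase_area * width),
        int(top - increase_area * height),
        int(right + increase_area * width),
        int(bot + increase_area * height),
    )


def clamp_bbox(bbox, h, w):
    """Clip an (x1, y1, x2, y2) box to an image of height h and width w."""
    return (max(bbox[0], 0), max(bbox[1], 0), min(bbox[2], w), min(bbox[3], h))


big = grow_bbox((10, 10, 30, 50), 0.5)
print(big, clamp_bbox(big, h=60, w=100))  # (0, -10, 40, 70) (0, 0, 40, 60)
```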
FILE: facelib/utils/misc.py
  function download_pretrained_models (line 14) | def download_pretrained_models(file_ids, save_path_root):
  function imwrite (line 38) | def imwrite(img, file_path, params=None, auto_mkdir=True):
  function img2tensor (line 57) | def img2tensor(imgs, bgr2rgb=True, float32=True):
  function load_file_from_url (line 86) | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
  function scandir (line 106) | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
  function is_gray (line 146) | def is_gray(img, threshold=10):
  function rgb2gray (line 162) | def rgb2gray(img, out_channel=3):
  function bgr2gray (line 169) | def bgr2gray(img, out_channel=3):
  function calc_mean_std (line 177) | def calc_mean_std(feat, eps=1e-5):
  function adain_npy (line 191) | def adain_npy(content_feat, style_feat):

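`calc_mean_std` and `adain_npy` suggest adaptive instance normalisation (AdaIN) on numpy arrays. The content features are normalised per channel, then rescaled with the style features' per-channel mean and standard deviation. The sketch below assumes (h, w, c) inputs. `channel_mean_std` and `adain_sketch` are hypothetical names.

```python
import numpy as np


def channel_mean_std(feat, eps=1e-5):
    """Per-channel mean and std of an (h, w, c) array, shaped (1, 1, c) for broadcasting."""
    c = feat.shape[2]
    flat = feat.reshape(-1, c)
    mean = flat.mean(axis=0).reshape(1, 1, c)
    std = np.sqrt(flat.var(axis=0) + eps).reshape(1, 1, c)
    return mean, std


def adain_sketch(content, style):
    """Give `content` the per-channel mean and std of `style`."""
    c_mean, c_std = channel_mean_std(content)
    s_mean, s_std = channel_mean_std(style)
    return (content - c_mean) / c_std * s_std + s_mean


rng = np.random.default_rng(0)
content = rng.normal(size=(8, 8, 3))
style = rng.normal(loc=5.0, scale=2.0, size=(8, 8, 3))
out = adain_sketch(content, style)
```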
FILE: scripts/inference.py
  function interpolate_sequence (line 19) | def interpolate_sequence(sequence):
  function set_realesrgan (line 35) | def set_realesrgan():

FILE: scripts/inference_color_and_inpainting.py
  function interpolate_sequence (line 19) | def interpolate_sequence(sequence):
  function set_realesrgan (line 35) | def set_realesrgan():

FILE: scripts/warp_images.py
  function interpolate_sequence (line 23) | def interpolate_sequence(sequence):
  function process_single (line 38) | def process_single(args, face_helper, input_path, ldmk_folder_path):
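`interpolate_sequence` appears in all three scripts, but its body is not shown here. If, as the name suggests, it fills frames whose landmarks are missing, one hypothetical approach is linear interpolation over the frame index, holding edge frames constant. `fill_missing_frames` is an invented name, and representing missing frames as `None` is an assumption.

```python
import numpy as np


def fill_missing_frames(sequence):
    """Linearly interpolate None entries in a per-frame list of equally shaped arrays.

    Leading and trailing gaps copy the nearest valid frame.
    """
    valid = [i for i, item in enumerate(sequence) if item is not None]
    if not valid:
        raise ValueError("no valid frames to interpolate from")
    stacked = np.stack([np.asarray(sequence[i], dtype=float) for i in valid])
    flat = stacked.reshape(len(valid), -1)
    frames = np.arange(len(sequence))
    # np.interp holds the end values outside [valid[0], valid[-1]]
    filled = np.stack(
        [np.interp(frames, valid, flat[:, j]) for j in range(flat.shape[1])], axis=1
    )
    return list(filled.reshape((len(sequence),) + stacked.shape[1:]))


seq = [None, np.array([[0.0, 0.0]]), None, np.array([[2.0, 4.0]]), None]
out = fill_missing_frames(seq)
print(out[2])  # [[1. 2.]]
```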
Condensed preview: 100 files, each showing path, character count, and a content snippet. The full structured content is 733K chars.
[
  {
    "path": ".gitignore",
    "chars": 264,
    "preview": "# Ignore OS-generated files\n.DS_Store\nThumbs.db\n\n# Ignore compiled build artifacts\n*.class\n*.exe\n*.o\n*.so\n.eggs/\n*.egg-info/\n\n\n# Ignore files generated by package managers\nnode_modules/\n"
  },
  {
    "path": "README.md",
    "chars": 13301,
    "preview": "<h1 align='center'>DicFace: Dirichlet-Constrained Variational Codebook Learning for Temporally Coherent Video Face Resto"
  },
  {
    "path": "basicsr/VERSION",
    "chars": 6,
    "preview": "1.3.2\n"
  },
  {
    "path": "basicsr/__init__.py",
    "chars": 266,
    "preview": "# https://github.com/xinntao/BasicSR\n# flake8: noqa\nfrom .archs import *\nfrom .data import *\nfrom .losses import *\nfrom "
  },
  {
    "path": "basicsr/archs/__init__.py",
    "chars": 886,
    "preview": "import importlib\nfrom copy import deepcopy\nfrom os import path as osp\n\nfrom basicsr.utils import get_root_logger, scandi"
  },
  {
    "path": "basicsr/archs/arcface_arch.py",
    "chars": 8074,
    "preview": "import torch.nn as nn\nfrom basicsr.utils.registry import ARCH_REGISTRY\n\n\ndef conv3x3(inplanes, outplanes, stride=1):\n   "
  },
  {
    "path": "basicsr/archs/arch_util.py",
    "chars": 11431,
    "preview": "import collections.abc\nimport math\nimport torch\nimport torchvision\nimport warnings\nfrom distutils.version import LooseVe"
  },
  {
    "path": "basicsr/archs/dir_dist_codeformer_multiscale_arch.py",
    "chars": 15751,
    "preview": "import math\nimport numpy as np\nimport torch\nfrom torch import nn, Tensor\nimport torch.nn.functional as F\nfrom typing imp"
  },
  {
    "path": "basicsr/archs/rrdbnet_arch.py",
    "chars": 4620,
    "preview": "import torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom basicsr.utils.registry import ARCH_RE"
  },
  {
    "path": "basicsr/archs/vgg_arch.py",
    "chars": 6132,
    "preview": "import os\nimport torch\nfrom collections import OrderedDict\nfrom torch import nn as nn\nfrom torchvision.models import vgg"
  },
  {
    "path": "basicsr/archs/vqgan_arch.py",
    "chars": 15428,
    "preview": "'''\nVQGAN code, adapted from the original created by the Unleashing Transformers authors:\nhttps://github.com/samb-t/unle"
  },
  {
    "path": "basicsr/data/__init__.py",
    "chars": 4254,
    "preview": "import importlib\nimport numpy as np\nimport random\nimport torch\nimport torch.utils.data\nfrom copy import deepcopy\nfrom fu"
  },
  {
    "path": "basicsr/data/color_dataset.py",
    "chars": 8396,
    "preview": "import os\nimport random\nfrom pathlib import Path\n\nfrom PIL import Image\nimport cv2\nimport ffmpeg\nimport io\nimport av\nimp"
  },
  {
    "path": "basicsr/data/data_sampler.py",
    "chars": 1639,
    "preview": "import math\nimport torch\nfrom torch.utils.data.sampler import Sampler\n\n\nclass EnlargedSampler(Sampler):\n    \"\"\"Sampler t"
  },
  {
    "path": "basicsr/data/data_util.py",
    "chars": 17339,
    "preview": "import cv2\nimport math\nimport numpy as np\nimport torch\nfrom os import path as osp\nfrom PIL import Image, ImageDraw\nfrom "
  },
  {
    "path": "basicsr/data/degradations.py",
    "chars": 28194,
    "preview": "import cv2\nimport math\nimport numpy as np\nimport random\nimport torch\nfrom scipy import special\nfrom scipy.stats import m"
  },
  {
    "path": "basicsr/data/gaussian_kernels.py",
    "chars": 25543,
    "preview": "import math\nimport numpy as np\nimport random\nfrom scipy.ndimage.interpolation import shift\nfrom scipy.stats import multi"
  },
  {
    "path": "basicsr/data/inpainting_dataset.py",
    "chars": 8376,
    "preview": "import os\nimport random\nfrom pathlib import Path\n\nfrom PIL import Image\nimport cv2\nimport ffmpeg\nimport io\nimport av\nimp"
  },
  {
    "path": "basicsr/data/paired_image_dataset.py",
    "chars": 4537,
    "preview": "from torch.utils import data as data\nfrom torchvision.transforms.functional import normalize\n\nfrom basicsr.data.data_uti"
  },
  {
    "path": "basicsr/data/prefetch_dataloader.py",
    "chars": 3131,
    "preview": "import queue as Queue\nimport threading\nimport torch\nfrom torch.utils.data import DataLoader\n\n\nclass PrefetchGenerator(th"
  },
  {
    "path": "basicsr/data/transforms.py",
    "chars": 5571,
    "preview": "import cv2\nimport random\n\n\ndef mod_crop(img, scale):\n    \"\"\"Mod crop images, used during testing.\n\n    Args:\n        img"
  },
  {
    "path": "basicsr/data/vfhq_dataset.py",
    "chars": 12306,
    "preview": "import os\nimport random\nfrom pathlib import Path\n\nfrom PIL import Image\nimport cv2\nimport ffmpeg\nimport io\nimport av\nimp"
  },
  {
    "path": "basicsr/losses/__init__.py",
    "chars": 836,
    "preview": "from copy import deepcopy\n\nfrom basicsr.utils import get_root_logger\nfrom basicsr.utils.registry import LOSS_REGISTRY\nfr"
  },
  {
    "path": "basicsr/losses/loss_util.py",
    "chars": 2903,
    "preview": "import functools\nfrom torch.nn import functional as F\n\n\ndef reduce_loss(loss, reduction):\n    \"\"\"Reduce loss as specifie"
  },
  {
    "path": "basicsr/losses/losses.py",
    "chars": 17690,
    "preview": "import math\nimport lpips\nimport torch\nfrom torch import autograd as autograd\nfrom torch import nn as nn\nfrom torch.nn im"
  },
  {
    "path": "basicsr/metrics/__init__.py",
    "chars": 507,
    "preview": "from copy import deepcopy\n\nfrom basicsr.utils.registry import METRIC_REGISTRY\nfrom .psnr_ssim import calculate_psnr, cal"
  },
  {
    "path": "basicsr/metrics/metric_util.py",
    "chars": 1288,
    "preview": "import numpy as np\n\nfrom basicsr.utils.matlab_functions import bgr2ycbcr\n\n\ndef reorder_image(img, input_order='HWC'):\n  "
  },
  {
    "path": "basicsr/metrics/psnr_ssim.py",
    "chars": 4563,
    "preview": "import cv2\nimport numpy as np\n\nfrom basicsr.metrics.metric_util import reorder_image, to_y_channel\nfrom basicsr.utils.re"
  },
  {
    "path": "basicsr/models/__init__.py",
    "chars": 1014,
    "preview": "import importlib\nfrom copy import deepcopy\nfrom os import path as osp\n\nfrom basicsr.utils import get_root_logger, scandi"
  },
  {
    "path": "basicsr/models/base_model.py",
    "chars": 12482,
    "preview": "import logging\nimport os\nimport torch\nfrom collections import OrderedDict\nfrom copy import deepcopy\nfrom torch.nn.parall"
  },
  {
    "path": "basicsr/models/codeformer_dirichlet_video_model.py",
    "chars": 14882,
    "preview": "import torch\nfrom collections import OrderedDict\nfrom os import path as osp\nfrom tqdm import tqdm\nfrom einops import rea"
  },
  {
    "path": "basicsr/models/lr_scheduler.py",
    "chars": 3956,
    "preview": "import math\nfrom collections import Counter\nfrom torch.optim.lr_scheduler import _LRScheduler\n\n\nclass MultiStepRestartLR"
  },
  {
    "path": "basicsr/models/sr_model.py",
    "chars": 8264,
    "preview": "import torch\nfrom collections import OrderedDict\nfrom os import path as osp\nfrom tqdm import tqdm\n\nfrom basicsr.archs im"
  },
  {
    "path": "basicsr/models/vqgan_model.py",
    "chars": 11645,
    "preview": "import torch\nfrom collections import OrderedDict\nfrom os import path as osp\nfrom tqdm import tqdm\n\nfrom basicsr.archs im"
  },
  {
    "path": "basicsr/ops/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "basicsr/ops/dcn/__init__.py",
    "chars": 306,
    "preview": "from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,\n       "
  },
  {
    "path": "basicsr/ops/dcn/deform_conv.py",
    "chars": 15574,
    "preview": "import math\nimport torch\nfrom torch import nn as nn\nfrom torch.autograd import Function\nfrom torch.autograd.function imp"
  },
  {
    "path": "basicsr/ops/dcn/src/deform_conv_cuda.cpp",
    "chars": 28838,
    "preview": "// modify from\n// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/def"
  },
  {
    "path": "basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu",
    "chars": 42622,
    "preview": "/*!\n ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************\n *\n * COPYRIGHT\n *\n * All contribu"
  },
  {
    "path": "basicsr/ops/dcn/src/deform_conv_ext.cpp",
    "chars": 7492,
    "preview": "// modify from\n// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/def"
  },
  {
    "path": "basicsr/ops/fused_act/__init__.py",
    "chars": 106,
    "preview": "from .fused_act import FusedLeakyReLU, fused_leaky_relu\n\n__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']\n"
  },
  {
    "path": "basicsr/ops/fused_act/fused_act.py",
    "chars": 2717,
    "preview": "# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501\n\nimport torch\nfrom"
  },
  {
    "path": "basicsr/ops/fused_act/src/fused_bias_act.cpp",
    "chars": 1092,
    "preview": "// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp\n#include <torch/extension.h>\n\n"
  },
  {
    "path": "basicsr/ops/fused_act/src/fused_bias_act_kernel.cu",
    "chars": 2874,
    "preview": "// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu\n// Copyright (c) 2019, N"
  },
  {
    "path": "basicsr/ops/upfirdn2d/__init__.py",
    "chars": 58,
    "preview": "from .upfirdn2d import upfirdn2d\n\n__all__ = ['upfirdn2d']\n"
  },
  {
    "path": "basicsr/ops/upfirdn2d/src/upfirdn2d.cpp",
    "chars": 1052,
    "preview": "// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp\n#include <torch/extension.h>\n\n\ntorc"
  },
  {
    "path": "basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu",
    "chars": 11803,
    "preview": "// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu\n// Copyright (c) 2019, NVIDIA"
  },
  {
    "path": "basicsr/ops/upfirdn2d/upfirdn2d.py",
    "chars": 5865,
    "preview": "# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py  # noqa:E501\n\nimport torch\nfro"
  },
  {
    "path": "basicsr/setup.py",
    "chars": 5205,
    "preview": "#!/usr/bin/env python\n\nfrom setuptools import find_packages, setup\n\nimport os\nimport subprocess\nimport sys\nimport time\nf"
  },
  {
    "path": "basicsr/train.py",
    "chars": 9735,
    "preview": "import argparse\nimport datetime\nimport logging\nimport math\nimport copy\nimport random\nimport time\nimport torch\nfrom os im"
  },
  {
    "path": "basicsr/utils/__init__.py",
    "chars": 862,
    "preview": "from .file_client import FileClient\nfrom .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, ten"
  },
  {
    "path": "basicsr/utils/dist_util.py",
    "chars": 2706,
    "preview": "# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py  # noqa: E501\nimport functools\n"
  },
  {
    "path": "basicsr/utils/download_util.py",
    "chars": 3369,
    "preview": "import math\nimport os\nimport requests\nfrom torch.hub import download_url_to_file, get_dir\nfrom tqdm import tqdm\nfrom url"
  },
  {
    "path": "basicsr/utils/file_client.py",
    "chars": 6017,
    "preview": "# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py  # noqa: E501\nfrom abc import "
  },
  {
    "path": "basicsr/utils/img_util.py",
    "chars": 8073,
    "preview": "import cv2\nimport math\nimport numpy as np\nfrom PIL import Image\nimport os\nimport torch\nfrom torchvision.utils import mak"
  },
  {
    "path": "basicsr/utils/lmdb_util.py",
    "chars": 7105,
    "preview": "import cv2\nimport lmdb\nimport sys\nfrom multiprocessing import Pool\nfrom os import path as osp\nfrom tqdm import tqdm\n\n\nde"
  },
  {
    "path": "basicsr/utils/logger.py",
    "chars": 6463,
    "preview": "import datetime\nimport logging\nimport time\n\nfrom .dist_util import get_dist_info, master_only\n\ninitialized_logger = {}\n\n"
  },
  {
    "path": "basicsr/utils/matlab_functions.py",
    "chars": 13523,
    "preview": "import math\nimport numpy as np\nimport torch\n\n\ndef cubic(x):\n    \"\"\"cubic function used for calculate_weights_indices.\"\"\""
  },
  {
    "path": "basicsr/utils/misc.py",
    "chars": 5193,
    "preview": "import os\nimport re\nimport random\nimport time\nimport torch\nimport numpy as np\nfrom os import path as osp\n\nfrom .dist_uti"
  },
  {
    "path": "basicsr/utils/options.py",
    "chars": 3496,
    "preview": "import yaml\nimport time\nfrom collections import OrderedDict\nfrom os import path as osp\nfrom basicsr.utils.misc import ge"
  },
  {
    "path": "basicsr/utils/realesrgan_utils.py",
    "chars": 12264,
    "preview": "import cv2\nimport math\nimport numpy as np\nimport os\nimport queue\nimport threading\nimport torch\nfrom torch.nn import func"
  },
  {
    "path": "basicsr/utils/registry.py",
    "chars": 2185,
    "preview": "# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py  # noqa: E501\n\n\nclass "
  },
  {
    "path": "basicsr/utils/video_util.py",
    "chars": 4578,
    "preview": "'''\nThe code is modified from the Real-ESRGAN:\nhttps://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_v"
  },
  {
    "path": "basicsr/version.py",
    "chars": 128,
    "preview": "# GENERATED VERSION FILE\n# TIME: Thu Jun 26 05:59:40 2025\n__version__ = '1.3.2'\n__gitsha__ = '536df45'\nversion_info = (1"
  },
  {
    "path": "facelib/detection/__init__.py",
    "chars": 4396,
    "preview": "import os\nimport torch\nfrom torch import nn\nfrom copy import deepcopy\n\nfrom facelib.utils import load_file_from_url\nfrom"
  },
  {
    "path": "facelib/detection/align_trans.py",
    "chars": 7941,
    "preview": "import cv2\nimport numpy as np\n\nfrom .matlab_cp2tform import get_similarity_transform_for_cv2\n\n# reference facial points,"
  },
  {
    "path": "facelib/detection/matlab_cp2tform.py",
    "chars": 8109,
    "preview": "import numpy as np\nfrom numpy.linalg import inv, lstsq\nfrom numpy.linalg import matrix_rank as rank\nfrom numpy.linalg im"
  },
  {
    "path": "facelib/detection/retinaface/retinaface.py",
    "chars": 13397,
    "preview": "import cv2\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom PIL import Image\nf"
  },
  {
    "path": "facelib/detection/retinaface/retinaface_net.py",
    "chars": 6281,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef conv_bn(inp, oup, stride=1, leaky=0):\n    retur"
  },
  {
    "path": "facelib/detection/retinaface/retinaface_utils.py",
    "chars": 16362,
    "preview": "import numpy as np\nimport torch\nimport torchvision\nfrom itertools import product as product\nfrom math import ceil\n\n\nclas"
  },
  {
    "path": "facelib/detection/yolov5face/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "facelib/detection/yolov5face/face_detector.py",
    "chars": 5894,
    "preview": "import cv2\nimport copy\nimport re\nimport torch\nimport numpy as np\n\nfrom pathlib import Path\nfrom facelib.detection.yolov5"
  },
  {
    "path": "facelib/detection/yolov5face/models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "facelib/detection/yolov5face/models/common.py",
    "chars": 11561,
    "preview": "# This file contains modules common to various models\n\nimport math\n\nimport numpy as np\nimport torch\nfrom torch import nn"
  },
  {
    "path": "facelib/detection/yolov5face/models/experimental.py",
    "chars": 1723,
    "preview": "# # This file contains experimental modules\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\nfrom facelib.detectio"
  },
  {
    "path": "facelib/detection/yolov5face/models/yolo.py",
    "chars": 9730,
    "preview": "import math\nfrom copy import deepcopy\nfrom pathlib import Path\n\nimport torch\nimport yaml  # for torch hub\nfrom torch imp"
  },
  {
    "path": "facelib/detection/yolov5face/models/yolov5l.yaml",
    "chars": 1344,
    "preview": "# parameters\nnc: 1  # number of classes\ndepth_multiple: 1.0  # model depth multiple\nwidth_multiple: 1.0  # layer channel"
  },
  {
    "path": "facelib/detection/yolov5face/models/yolov5n.yaml",
    "chars": 1335,
    "preview": "# parameters\nnc: 1  # number of classes\ndepth_multiple: 1.0  # model depth multiple\nwidth_multiple: 1.0  # layer channel"
  },
  {
    "path": "facelib/detection/yolov5face/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "facelib/detection/yolov5face/utils/autoanchor.py",
    "chars": 460,
    "preview": "# Auto-anchor utils\n\n\ndef check_anchor_order(m):\n    # Check anchor order against stride order for YOLOv5 Detect() modul"
  },
  {
    "path": "facelib/detection/yolov5face/utils/datasets.py",
    "chars": 1515,
    "preview": "import cv2\nimport numpy as np\n\n\ndef letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=Fa"
  },
  {
    "path": "facelib/detection/yolov5face/utils/extract_ckpt.py",
    "chars": 235,
    "preview": "import torch\nimport sys\nsys.path.insert(0,'./facelib/detection/yolov5face')\nmodel = torch.load('facelib/detection/yolov5"
  },
  {
    "path": "facelib/detection/yolov5face/utils/general.py",
    "chars": 10348,
    "preview": "import math\nimport time\n\nimport numpy as np\nimport torch\nimport torchvision\n\n\ndef check_img_size(img_size, s=32):\n    # "
  },
  {
    "path": "facelib/detection/yolov5face/utils/torch_utils.py",
    "chars": 1375,
    "preview": "import torch\nfrom torch import nn\n\n\ndef fuse_conv_and_bn(conv, bn):\n    # Fuse convolution and batchnorm layers https://"
  },
  {
    "path": "facelib/parsing/__init__.py",
    "chars": 959,
    "preview": "import torch\n\nfrom facelib.utils import load_file_from_url\nfrom .bisenet import BiSeNet\nfrom .parsenet import ParseNet\n\n"
  },
  {
    "path": "facelib/parsing/bisenet.py",
    "chars": 5190,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .resnet import ResNet18\n\n\nclass ConvBNReLU(nn.M"
  },
  {
    "path": "facelib/parsing/parsenet.py",
    "chars": 6477,
    "preview": "\"\"\"Modified from https://github.com/chaofengc/PSFRGAN\n\"\"\"\nimport numpy as np\nimport torch.nn as nn\nfrom torch.nn import "
  },
  {
    "path": "facelib/parsing/resnet.py",
    "chars": 2357,
    "preview": "import torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"\"\"3x3 convolu"
  },
  {
    "path": "facelib/utils/__init__.py",
    "chars": 389,
    "preview": "from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back\nfrom .misc "
  },
  {
    "path": "facelib/utils/face_restoration_helper.py",
    "chars": 31523,
    "preview": "import cv2\nimport numpy as np\nimport os\nimport torch\nimport pdb\nimport dlib\nfrom torchvision.transforms.functional impor"
  },
  {
    "path": "facelib/utils/face_utils.py",
    "chars": 10061,
    "preview": "import cv2\nimport numpy as np\nimport torch\n\n\ndef compute_increased_bbox(bbox, increase_area, preserve_aspect=True):\n    "
  },
  {
    "path": "facelib/utils/misc.py",
    "chars": 7157,
    "preview": "import cv2\nimport os\nimport os.path as osp\nimport numpy as np\nfrom PIL import Image\nimport torch\nfrom torch.hub import d"
  },
  {
    "path": "options/clip5_bs2_512_align_nofix_multiscale.yaml",
    "chars": 4765,
    "preview": "# general settings\nname: BFR_test\nmodel_type: CodeFormerDirichletVideoModel\nnum_gpu: 1\nmanual_seed: 0\n\n# dataset and dat"
  },
  {
    "path": "options/clip5_bs2_512_align_nofix_multiscale_color.yaml",
    "chars": 5589,
    "preview": "# general settings\nname: codeformer_dirichlet_clip5_bs2_align_nofix_multiscale_color\nmodel_type: CodeFormerDirichletVide"
  },
  {
    "path": "options/clip5_bs2_512_align_nofix_multiscale_inpaint.yaml",
    "chars": 5649,
    "preview": "# general settings\nname: codeformer_dirichlet_clip5_bs2_align_nofix_multiscale_inpaint\nmodel_type: CodeFormerDirichletVi"
  },
  {
    "path": "requirements.txt",
    "chars": 176,
    "preview": "addict\nfuture\nlmdb\nnumpy\nopencv-python\nPillow\npyyaml\nrequests\nscikit-image\nscipy\n# tb-nightly\ntensorboard\ntorch>=1.7.1\nt"
  },
  {
    "path": "scripts/inference.py",
    "chars": 16195,
    "preview": "import os\nimport cv2\nimport argparse\nimport glob\nimport torch\nimport numpy as np\nfrom torchvision.transforms.functional "
  },
  {
    "path": "scripts/inference_color_and_inpainting.py",
    "chars": 13141,
    "preview": "import os\nimport cv2\nimport argparse\nimport glob\nimport torch\nimport numpy as np\nfrom torchvision.transforms.functional "
  },
  {
    "path": "scripts/warp_images.py",
    "chars": 5668,
    "preview": "import os\nimport cv2\nimport argparse\nimport glob\nimport torch\nimport pdb\nimport numpy as np\nfrom tqdm import tqdm\nfrom t"
  },
  {
    "path": "train.sh",
    "chars": 186,
    "preview": "CUDA_VISIBLE_DEVICES=0 torchrun \\\n    --nproc_per_node=1 --master_port=29597 \\\n    basicsr/train.py \\\n    -opt options/c"
  }
]

About this extraction

This page contains the full source code of the fudan-generative-vision/DicFace GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 100 files (687.8 KB), approximately 188.1k tokens, and a symbol index with 741 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract.