Full Code of sczhou/CodeFormer for AI

master b33cc7d639d6 cached
115 files
692.8 KB
189.9k tokens
722 symbols
1 requests
Download .txt
Showing preview only (729K chars total). Download the full file or copy to clipboard to get everything.
Repository: sczhou/CodeFormer
Branch: master
Commit: b33cc7d639d6
Files: 115
Total size: 692.8 KB

Directory structure:
gitextract_zvvj76bx/

├── .gitignore
├── LICENSE
├── README.md
├── basicsr/
│   ├── VERSION
│   ├── __init__.py
│   ├── archs/
│   │   ├── __init__.py
│   │   ├── arcface_arch.py
│   │   ├── arch_util.py
│   │   ├── codeformer_arch.py
│   │   ├── rrdbnet_arch.py
│   │   ├── vgg_arch.py
│   │   └── vqgan_arch.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── data_sampler.py
│   │   ├── data_util.py
│   │   ├── ffhq_blind_dataset.py
│   │   ├── ffhq_blind_joint_dataset.py
│   │   ├── gaussian_kernels.py
│   │   ├── paired_image_dataset.py
│   │   ├── prefetch_dataloader.py
│   │   └── transforms.py
│   ├── losses/
│   │   ├── __init__.py
│   │   ├── loss_util.py
│   │   └── losses.py
│   ├── metrics/
│   │   ├── __init__.py
│   │   ├── metric_util.py
│   │   └── psnr_ssim.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── codeformer_idx_model.py
│   │   ├── codeformer_joint_model.py
│   │   ├── codeformer_model.py
│   │   ├── lr_scheduler.py
│   │   ├── sr_model.py
│   │   └── vqgan_model.py
│   ├── ops/
│   │   ├── __init__.py
│   │   ├── dcn/
│   │   │   ├── __init__.py
│   │   │   ├── deform_conv.py
│   │   │   └── src/
│   │   │       ├── deform_conv_cuda.cpp
│   │   │       ├── deform_conv_cuda_kernel.cu
│   │   │       └── deform_conv_ext.cpp
│   │   ├── fused_act/
│   │   │   ├── __init__.py
│   │   │   ├── fused_act.py
│   │   │   └── src/
│   │   │       ├── fused_bias_act.cpp
│   │   │       └── fused_bias_act_kernel.cu
│   │   └── upfirdn2d/
│   │       ├── __init__.py
│   │       ├── src/
│   │       │   ├── upfirdn2d.cpp
│   │       │   └── upfirdn2d_kernel.cu
│   │       └── upfirdn2d.py
│   ├── setup.py
│   ├── train.py
│   └── utils/
│       ├── __init__.py
│       ├── dist_util.py
│       ├── download_util.py
│       ├── file_client.py
│       ├── img_util.py
│       ├── lmdb_util.py
│       ├── logger.py
│       ├── matlab_functions.py
│       ├── misc.py
│       ├── options.py
│       ├── realesrgan_utils.py
│       ├── registry.py
│       └── video_util.py
├── docs/
│   ├── history_changelog.md
│   ├── train.md
│   └── train_CN.md
├── facelib/
│   ├── detection/
│   │   ├── __init__.py
│   │   ├── align_trans.py
│   │   ├── matlab_cp2tform.py
│   │   ├── retinaface/
│   │   │   ├── retinaface.py
│   │   │   ├── retinaface_net.py
│   │   │   └── retinaface_utils.py
│   │   └── yolov5face/
│   │       ├── __init__.py
│   │       ├── face_detector.py
│   │       ├── models/
│   │       │   ├── __init__.py
│   │       │   ├── common.py
│   │       │   ├── experimental.py
│   │       │   ├── yolo.py
│   │       │   ├── yolov5l.yaml
│   │       │   └── yolov5n.yaml
│   │       └── utils/
│   │           ├── __init__.py
│   │           ├── autoanchor.py
│   │           ├── datasets.py
│   │           ├── extract_ckpt.py
│   │           ├── general.py
│   │           └── torch_utils.py
│   ├── parsing/
│   │   ├── __init__.py
│   │   ├── bisenet.py
│   │   ├── parsenet.py
│   │   └── resnet.py
│   └── utils/
│       ├── __init__.py
│       ├── face_restoration_helper.py
│       ├── face_utils.py
│       └── misc.py
├── inference_codeformer.py
├── inference_colorization.py
├── inference_inpainting.py
├── options/
│   ├── CodeFormer_colorization.yml
│   ├── CodeFormer_inpainting.yml
│   ├── CodeFormer_stage2.yml
│   ├── CodeFormer_stage3.yml
│   └── VQGAN_512_ds32_nearest_stage1.yml
├── requirements.txt
├── scripts/
│   ├── crop_align_face.py
│   ├── download_pretrained_models.py
│   ├── download_pretrained_models_from_gdrive.py
│   ├── generate_latent_gt.py
│   └── inference_vqgan.py
├── web-demos/
│   ├── hugging_face/
│   │   └── app.py
│   └── replicate/
│       ├── cog.yaml
│       └── predict.py
└── weights/
    ├── CodeFormer/
    │   └── .gitkeep
    ├── README.md
    └── facelib/
        └── .gitkeep

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.vscode

# ignored files
version.py

# ignored files with suffix
*.html
# *.png
# *.jpeg
# *.jpg
*.pt
*.gif
*.pth
*.dat
*.zip

# template

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# project
results/
experiments/
tb_logger/
run.sh
*debug*
*_old*



================================================
FILE: LICENSE
================================================
S-Lab License 1.0

Copyright 2022 S-Lab

Redistribution and use for non-commercial purpose in source and 
binary forms, with or without modification, are permitted provided 
that the following conditions are met:

1. Redistributions of source code must retain the above copyright 
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright 
   notice, this list of conditions and the following disclaimer in 
   the documentation and/or other materials provided with the 
   distribution.

3. Neither the name of the copyright holder nor the names of its 
   contributors may be used to endorse or promote products derived 
   from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

In the event that redistribution and/or use for commercial purpose in 
source or binary forms, with or without modification is required, 
please contact the contributor(s) of the work.

================================================
FILE: README.md
================================================
<p align="center">
  <img src="assets/CodeFormer_logo.png" height=110>
</p>

## Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022)

[Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)


<a href="https://colab.research.google.com/drive/1m52PNveE4PBhYrecj34cnpEeiHcC5LTb?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a> [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer) [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer) [![OpenXLab](https://img.shields.io/badge/Demo-%F0%9F%90%BC%20OpenXLab-blue)](https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer) ![Visitors](https://api.infinitescript.com/badgen/count?name=sczhou/CodeFormer&ltext=Visitors)


[Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/) 

S-Lab, Nanyang Technological University

<img src="assets/network.jpg" width="800px"/>


:star: If CodeFormer is helpful for your images or projects, please help star this repo. Thanks! :hugs: 


### Update
- **2023.07.20**: Integrated into :panda_face: [OpenXLab](https://openxlab.org.cn/apps). Try out the online demo! [![OpenXLab](https://img.shields.io/badge/Demo-%F0%9F%90%BC%20OpenXLab-blue)](https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer)
- **2023.04.19**: :whale: Training code and config files are now publicly available.
- **2023.04.09**: Added inpainting and colorization features for cropped and aligned face images.
- **2023.02.10**: Included `dlib` as a new face detector option; it produces more accurate face identity.
- **2022.10.05**: Added support for video input `--input_path [YOUR_VIDEO.mp4]`. Try it to enhance your videos! :clapper: 
- **2022.09.14**: Integrated into :hugs: [Hugging Face](https://huggingface.co/spaces). Try out the online demo! [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer)
- **2022.09.09**: Integrated into :rocket: [Replicate](https://replicate.com/explore). Try out the online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer)
- [**More**](docs/history_changelog.md)

### TODO
- [x] Add training code and config files
- [x] Add checkpoint and script for face inpainting
- [x] Add checkpoint and script for face colorization
- [x] ~~Add background image enhancement~~

#### :panda_face: Try Enhancing Old Photos / Fixing AI-arts
[<img src="assets/imgsli_1.jpg" height="226px"/>](https://imgsli.com/MTI3NTE2) [<img src="assets/imgsli_2.jpg" height="226px"/>](https://imgsli.com/MTI3NTE1) [<img src="assets/imgsli_3.jpg" height="226px"/>](https://imgsli.com/MTI3NTIw) 

#### Face Restoration

<img src="assets/restoration_result1.png" width="400px"/> <img src="assets/restoration_result2.png" width="400px"/>
<img src="assets/restoration_result3.png" width="400px"/> <img src="assets/restoration_result4.png" width="400px"/>

#### Face Color Enhancement and Restoration

<img src="assets/color_enhancement_result1.png" width="400px"/> <img src="assets/color_enhancement_result2.png" width="400px"/>

#### Face Inpainting

<img src="assets/inpainting_result1.png" width="400px"/> <img src="assets/inpainting_result2.png" width="400px"/>



### Dependencies and Installation

- Pytorch >= 1.7.1
- CUDA >= 10.1
- Other required packages in `requirements.txt`
```
# git clone this repository
git clone https://github.com/sczhou/CodeFormer
cd CodeFormer

# create new anaconda env
conda create -n codeformer python=3.8 -y
conda activate codeformer

# install python dependencies
pip3 install -r requirements.txt
python basicsr/setup.py develop
conda install -c conda-forge dlib  # only for face detection or cropping with dlib
```
<!-- conda install -c conda-forge dlib -->

### Quick Inference

#### Download Pre-trained Models:
Download the facelib and dlib pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0) | [Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder. You can manually download the pretrained models OR download by running the following command:
```
python scripts/download_pretrained_models.py facelib
python scripts/download_pretrained_models.py dlib  # only for the dlib face detector
```

Download the CodeFormer pretrained models from [[Releases](https://github.com/sczhou/CodeFormer/releases/tag/v0.1.0) | [Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder. You can manually download the pretrained models OR download by running the following command:
```
python scripts/download_pretrained_models.py CodeFormer
```

#### Prepare Testing Data:
You can put the testing images in the `inputs/TestWhole` folder. If you would like to test on cropped and aligned faces, you can put them in the `inputs/cropped_faces` folder. You can get the cropped and aligned faces by running the following command:
```
# you may need to install dlib via: conda install -c conda-forge dlib
python scripts/crop_align_face.py -i [input folder] -o [output folder]
```


#### Testing:
[Note] If you want to compare against CodeFormer in your paper, please run the following command with `--has_aligned` (for cropped and aligned faces). The whole-image command involves a face-background fusion step that may damage hair texture at the boundary, which would lead to an unfair comparison.

Fidelity weight *w* lies in [0, 1]. Generally, a smaller *w* tends to produce a higher-quality result, while a larger *w* yields a higher-fidelity result. The results will be saved in the `results` folder.


🧑🏻 Face Restoration (cropped and aligned face)
```
# For cropped and aligned faces (512x512)
python inference_codeformer.py -w 0.5 --has_aligned --input_path [image folder]|[image path]
```

:framed_picture: Whole Image Enhancement
```
# For whole image
# Add '--bg_upsampler realesrgan' to enhance the background regions with Real-ESRGAN
# Add '--face_upsample' to further upsample the restored face with Real-ESRGAN
python inference_codeformer.py -w 0.7 --input_path [image folder]|[image path]
```

:clapper: Video Enhancement
```
# For Windows/Mac users, please install ffmpeg first
conda install -c conda-forge ffmpeg
```
```
# For video clips
# Video path should end with '.mp4'|'.mov'|'.avi'
python inference_codeformer.py --bg_upsampler realesrgan --face_upsample -w 1.0 --input_path [video path]
```

🌈 Face Colorization (cropped and aligned face)
```
# For cropped and aligned faces (512x512)
# Colorize black and white or faded photo
python inference_colorization.py --input_path [image folder]|[image path]
```

🎨 Face Inpainting (cropped and aligned face)
```
# For cropped and aligned faces (512x512)
# Inputs can be masked with a white brush using an image editing app (e.g., Photoshop)
# (check out the examples in inputs/masked_faces)
python inference_inpainting.py --input_path [image folder]|[image path]
```
### Training:
The training commands can be found in the documents: [English](docs/train.md) **|** [简体中文](docs/train_CN.md).

### License

This project is licensed under <a rel="license" href="https://github.com/sczhou/CodeFormer/blob/master/LICENSE">NTU S-Lab License 1.0</a>. Redistribution and use should follow this license.

---
### 🐼 Ecosystem Applications & Deployments

CodeFormer has been widely adopted and deployed across a broad range (>20) of online applications, platforms, API services, and independent websites, and has also been integrated into many open-source projects and toolkits.

> Only demos on **Hugging Face Space**, **Replicate**, and **OpenXLab** are official deployments **maintained by the authors**. All other demos, APIs, apps, websites, and integrations listed below are **third-party (non-official)** and are not affiliated with the CodeFormer authors. Please verify their legitimacy to avoid potential financial loss.


#### Websites (Non-official)

⚠️⚠️⚠️ The following websites are **not official and are not operated by us**. They use our models without any license or authorization. Please verify their legitimacy to avoid potential financial loss.


| Website | Link | Notes |
|---------|------|--------|
| CodeFormer.net | https://codeformer.net/ | Non-official website |
| CodeFormer.cn | https://www.codeformer.cn/ | Non-official website |
| CodeFormerAI.com | https://codeformerai.com/ | Non-official website |

#### Online Demos / API Platforms

| Platform | Link | Notes |
|----------|------|--------|
| Hugging Face | https://huggingface.co/spaces/sczhou/CodeFormer | Maintained by Authors |
| Replicate | https://replicate.com/sczhou/codeformer | Maintained by Authors |
| OpenXLab | https://openxlab.org.cn/apps/detail/ShangchenZhou/CodeFormer |Maintained by Authors |
| Segmind | https://www.segmind.com/models/codeformer | Non-official |
| Sieve | https://www.sievedata.com/functions/sieve/codeformer | Non-official |
| Fal.ai | https://fal.ai/models/fal-ai/codeformer | Non-official |
| VaikerAI | https://vaikerai.com/sczhou/codeformer | Non-official |
| Scade.pro | https://www.scade.pro/processors/lucataco-codeformer | Non-official |
| Grandline | https://www.grandline.ai/model/codeformer | Non-official |
| AI Demos | https://aidemos.com/tools/codeformer | Non-official |
| Synexa | https://synexa.ai/explore/sczhou/codeformer | Non-official |
| RentPrompts | https://rentprompts.ai/models/Codeformer | Non-official |
| ElevaticsAI | https://elevatics.ai/models/super-resolution/codeformer | Non-official |
| Anakin.ai | https://anakin.ai/apps/codeformer-online-face-restoration-by-codeformer-19343 | Non-official |
| Relayto | https://relayto.com/explore/codeformer-yf9rj8kwc7zsr | Non-official |


#### Open-Source Projects & Toolkits

| Project / Toolkit | Link | Notes |
|-------------------|------|--------|
| Stable Diffusion GUI | https://nmkd.itch.io/t2i-gui | Integration |
| Stable Diffusion WebUI | https://github.com/AUTOMATIC1111/stable-diffusion-webui | Integration |
| ChaiNNer | https://github.com/chaiNNer-org/chaiNNer | Integration |
| PyPI | https://pypi.org/project/codeformer/ ; https://pypi.org/project/codeformer-pip/ | Python packages |
| ComfyUI | https://stable-diffusion-art.com/codeformer/ | Integration |

---
### Acknowledgement

This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Some code is borrowed from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks for their awesome work.

### Citation
If our work is useful for your research, please consider citing:

    @inproceedings{zhou2022codeformer,
        author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
        title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
        booktitle = {NeurIPS},
        year = {2022}
    }


### Contact
If you have any questions, please feel free to reach out to me at `shangchenzhou@gmail.com`. 


================================================
FILE: basicsr/VERSION
================================================
1.3.2


================================================
FILE: basicsr/__init__.py
================================================
# https://github.com/xinntao/BasicSR
# flake8: noqa
from .archs import *
from .data import *
from .losses import *
from .metrics import *
from .models import *
from .ops import *
from .train import *
from .utils import *
from .version import __gitsha__, __version__


================================================
FILE: basicsr/archs/__init__.py
================================================
import importlib
from copy import deepcopy
from os import path as osp

from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import ARCH_REGISTRY

__all__ = ['build_network']

# Automatically scan and import arch modules so that the classes they decorate
# with @ARCH_REGISTRY.register() are available to build_network() through
# ARCH_REGISTRY.get().
# Every file under the 'archs' folder whose name ends with '_arch.py' is
# collected; basename + splitext reduce each path to a bare module name
# (NOTE(review): assumes scandir yields file paths — confirm in basicsr.utils).
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# Import all the arch modules; importing is what triggers their registration.
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]


def build_network(opt):
    """Instantiate a network class registered in ``ARCH_REGISTRY``.

    Args:
        opt (dict): Network options. The ``'type'`` entry names the registered
            architecture; every remaining entry is forwarded to its constructor
            as a keyword argument. ``opt`` itself is not modified: a deep copy
            is consumed instead.

    Returns:
        The instantiated network object.
    """
    net_opt = deepcopy(opt)
    arch_name = net_opt.pop('type')
    net_cls = ARCH_REGISTRY.get(arch_name)
    net = net_cls(**net_opt)
    get_root_logger().info(f'Network [{net.__class__.__name__}] is created.')
    return net


================================================
FILE: basicsr/archs/arcface_arch.py
================================================
import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY


def conv3x3(inplanes, outplanes, stride=1):
    """Build a bias-free 3x3 convolution with padding 1.

    With ``stride=1`` the spatial size is preserved; larger strides downsample.

    Args:
        inplanes (int): Channel number of inputs.
        outplanes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.

    Returns:
        nn.Conv2d: The configured convolution layer.
    """
    layer = nn.Conv2d(
        in_channels=inplanes,
        out_channels=outplanes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


class BasicBlock(nn.Module):
    """Basic residual block used in the ResNetArcFace architecture.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its output is added to
    the input (passed through ``downsample`` when one is given) and a final
    ReLU is applied.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): Module applied to the shortcut branch. Default: None.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Submodule attribute names determine the state_dict keys; keep them stable.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        shortcut = x if self.downsample is None else self.downsample(x)
        main += shortcut
        return self.relu(main)


class IRBlock(nn.Module):
    """Improved residual block (IR Block) used in the ResNetArcFace architecture.

    Main path: BN -> conv3x3 -> BN -> PReLU -> conv3x3 (strided) -> BN, then an
    optional squeeze-and-excitation block. The result is added to the
    (optionally downsampled) input and passed through PReLU again. A single
    ``nn.PReLU`` module is reused for both activations, so they share
    parameters.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of outputs.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): Module applied to the shortcut branch. Default: None.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
    """
    expansion = 1  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super().__init__()
        # Submodule attribute names determine the state_dict keys; keep them stable.
        self.bn0 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        if use_se:
            self.se = SEBlock(planes)

    def forward(self, x):
        out = self.prelu(self.bn1(self.conv1(self.bn0(x))))
        out = self.bn2(self.conv2(out))
        if self.use_se:
            out = self.se(out)
        identity = x if self.downsample is None else self.downsample(x)
        out += identity
        return self.prelu(out)


class Bottleneck(nn.Module):
    """Bottleneck block used in the ResNetArcFace architecture.

    Main path: 1x1 conv -> BN -> ReLU -> 3x3 conv (strided) -> BN -> ReLU ->
    1x1 conv expanding to ``planes * expansion`` channels -> BN. It is added to
    the (optionally downsampled) input and a final ReLU is applied.

    Args:
        inplanes (int): Channel number of inputs.
        planes (int): Channel number of the inner convolutions; the block
            outputs ``planes * 4`` channels.
        stride (int): Stride in convolution. Default: 1.
        downsample (nn.Module): Module applied to the shortcut branch. Default: None.
    """
    expansion = 4  # output channel expansion ratio

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        out_channels = planes * self.expansion
        # Submodule attribute names determine the state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = x
        # The first two conv/BN stages are each followed by ReLU.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(bn(conv(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


class SEBlock(nn.Module):
    """The squeeze-and-excitation block (SEBlock) used in the IRBlock.

    Each channel is global-average-pooled. The pooled vector goes through
    Linear -> PReLU -> Linear -> Sigmoid, and the input channels are multiplied
    by the resulting per-channel gates.

    Args:
        channel (int): Channel number of inputs.
        reduction (int): Channel reduction ratio of the hidden layer
            (hidden width is ``channel // reduction``). Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # pool to 1x1 without spatial information
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden),
            nn.PReLU(),
            nn.Linear(hidden, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, chans, _, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, chans)
        gates = self.fc(squeezed).view(batch, chans, 1, 1)
        return x * gates


@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
    """ArcFace with ResNet architectures.

    Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.

    The stem takes single-channel input (``conv1`` has one input channel). The
    max-pool and the three stride-2 stages each halve the resolution. The
    flattened ``layer4`` output feeds ``fc5 = Linear(512 * 8 * 8, 512)``, so
    the feature map after ``layer4`` must be 512x8x8 (e.g. a 128x128 input
    with an expansion-1 block).

    Args:
        block (str | type): Block used in the ArcFace architecture. Either the
            name 'IRBlock' or a block class that exposes an ``expansion``
            attribute and accepts a ``use_se`` keyword argument.
        layers (tuple(int)): Block numbers in each layer.
        use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.

    Raises:
        ValueError: If ``block`` is a string other than 'IRBlock'.
    """

    def __init__(self, block, layers, use_se=True):
        # Set up nn.Module before assigning any attributes on it.
        super(ResNetArcFace, self).__init__()
        if isinstance(block, str):
            if block != 'IRBlock':
                raise ValueError(f"Unsupported block '{block}' for ResNetArcFace; expected 'IRBlock' or a block class.")
            block = IRBlock
        self.inplanes = 64
        self.use_se = use_se

        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.dropout = nn.Dropout()
        self.fc5 = nn.Linear(512 * 8 * 8, 512)
        self.bn5 = nn.BatchNorm1d(512)

        # initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Stack ``num_blocks`` blocks into one stage.

        Only the first block may change the stride or channel count. When it
        does, its shortcut gets a 1x1 conv + BN projection so the residual
        addition matches shapes.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
        # The following blocks (and the next stage) receive the first block's
        # output width, which is planes * expansion.
        self.inplanes = planes * block.expansion
        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn4(x)
        x = self.dropout(x)
        # Flatten to (batch, C*H*W) for fc5, which expects 512 * 8 * 8 features.
        x = x.view(x.size(0), -1)
        x = self.fc5(x)
        x = self.bn5(x)

        return x

================================================
FILE: basicsr/archs/arch_util.py
================================================
import collections.abc
import math
import torch
import torchvision
import warnings
from distutils.version import LooseVersion
from itertools import repeat
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm

from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
from basicsr.utils import get_root_logger


@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Every submodule (via ``.modules()``) is visited:
    * ``nn.Conv2d`` / ``nn.Linear``: Kaiming-normal weights (``kwargs`` are
      forwarded to ``init.kaiming_normal_``), then multiplied by ``scale``.
    * batch-norm layers: weight set to 1 when the layer has one
      (``affine=False`` layers have no weight or bias).
    For these layers, an existing bias is filled with ``bias_fill``. Other
    modules are left untouched.

    Args:
        module_list (list[nn.Module] | tuple[nn.Module] | nn.Module): Modules to
            be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, (list, tuple)):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.mul_(scale)  # safe in-place: the decorator disables autograd
            elif isinstance(m, _BatchNorm):
                # BN layers created with affine=False have weight = bias = None.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
            else:
                continue
            if m.bias is not None:
                m.bias.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack ``num_basic_block`` freshly constructed blocks in sequence.

    Args:
        basic_block (nn.module): Block class (or factory) called once per
            block with ``**kwarg``; each call yields an independent module.
        num_basic_block (int): number of blocks.
        **kwarg: Keyword arguments forwarded to every ``basic_block`` call.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    blocks = (basic_block(**kwarg) for _ in range(num_basic_block))
    return nn.Sequential(*blocks)


class ResidualBlockNoBN(nn.Module):
    """Residual block without batch normalization.

    Computes ``x + res_scale * conv2(relu(conv1(x)))``::

        ---Conv-ReLU-Conv-+-
         |________________|

    Both convolutions are 3x3 with stride 1 and padding 1, so the channel
    count and spatial size are preserved.

    Args:
        num_feat (int): Number of input/output (and intermediate) channels.
            Default: 64.
        res_scale (float): Multiplier applied to the residual branch. Default: 1.
        pytorch_init (bool): If True, keep PyTorch's default conv
            initialization; otherwise re-initialize both convs with
            ``default_init_weights`` at scale 0.1. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        # Attribute names double as state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        branch = self.conv1(x)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        return x + branch * self.res_scale


class Upsample(nn.Sequential):
    """Sub-pixel convolution (PixelShuffle) upsampling module.

    For ``scale == 2**n`` it stacks ``n`` pairs of
    ``Conv2d(num_feat, 4 * num_feat, 3, 1, 1)`` -> ``PixelShuffle(2)``
    (``scale == 1`` gives an empty, identity module). For ``scale == 3`` it
    uses a single ``Conv2d(num_feat, 9 * num_feat, 3, 1, 1)`` ->
    ``PixelShuffle(3)``. The channel count is preserved.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is not a positive power of two or 3.
    """

    def __init__(self, scale, num_feat):
        m = []
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # n = log2(scale), computed exactly with integer arithmetic
            # instead of a floating-point logarithm.
            for _ in range(int(scale).bit_length() - 1):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            # Non-positive scales land here too (0 passes the bit test above).
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)


def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2) holding per-pixel
            displacements in pixels; ``flow[..., 0]`` is added to the x (width)
            coordinate and ``flow[..., 1]`` to the y (height) coordinate.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Passed to ``F.grid_sample``. The conversion
            from pixel coordinates to grid_sample's [-1, 1] range is chosen to
            match this flag, so a zero flow is an identity warp in both modes.
            Default: True (the PyTorch default before 1.3).

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # Pixel-coordinate mesh grid, equivalent to torch.meshgrid with 'ij'
    # indexing but without the version-dependent `indexing` argument.
    grid_y = torch.arange(0, h).type_as(x).view(h, 1).expand(h, w)
    grid_x = torch.arange(0, w).type_as(x).view(1, w).expand(h, w)
    grid = torch.stack((grid_x, grid_y), 2).float()  # (h, w, 2), last dim = (x, y)

    vgrid = grid + flow
    # Map pixel coordinates to the [-1, 1] range expected by grid_sample.
    if align_corners:
        # -1 / 1 are the centers of the corner pixels.
        vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
        vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    else:
        # -1 / 1 are the outer edges of the corner pixels.
        vgrid_x = (2.0 * vgrid[:, :, :, 0] + 1.0) / w - 1.0
        vgrid_y = (2.0 * vgrid[:, :, :, 1] + 1.0) / h - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    return F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)


def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow field and rescale its displacement values to match.

    Args:
        flow (Tensor): Precomputed flow of shape [N, 2, H, W]. Channel 0 is
            multiplied by the width ratio and channel 1 by the height ratio.
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): With 'ratio', ``[ratio_h, ratio_w]``
            (values < 1.0 downsample, > 1.0 upsample; the output size is
            truncated with ``int``). With 'shape', ``[out_h, out_w]``.
        interp_mode (str): Interpolation mode for ``F.interpolate``.
            Default: 'bilinear'.
        align_corners (bool): Passed to ``F.interpolate``. Default: False.

    Returns:
        Tensor: Resized flow of shape [N, 2, out_h, out_w]. The input is not
        modified.

    Raises:
        ValueError: If ``size_type`` is neither 'ratio' nor 'shape'.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    elif size_type == 'ratio':
        output_h = int(flow_h * sizes[0])
        output_w = int(flow_w * sizes[1])
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    # Displacements are measured in pixels, so they scale with the resize ratio.
    scaled = flow.clone()
    scaled[:, 0, :, :] *= output_w / flow_w
    scaled[:, 1, :, :] *= output_h / flow_h
    return F.interpolate(input=scaled, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)


# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """Pixel unshuffle, the inverse of ``nn.PixelShuffle``.

    Folds every ``scale x scale`` spatial patch into the channel dimension:
    output channel ``ci * scale**2 + dy * scale + dx`` at ``(y, x)`` holds input
    channel ``ci`` at ``(y * scale + dy, x * scale + dx)``.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw); ``hh`` and ``hw``
            must be divisible by ``scale``.
        scale (int): Downsample ratio.

    Returns:
        Tensor: Feature of shape (b, c * scale**2, hh // scale, hw // scale).
    """
    b, c, hh, hw = x.size()
    assert hh % scale == 0 and hw % scale == 0
    h, w = hh // scale, hw // scale
    # (b, c, h, dy, w, dx) -> (b, c, dy, dx, h, w), then merge (c, dy, dx).
    patches = x.view(b, c, h, scale, w, scale).permute(0, 1, 3, 5, 2, 4)
    return patches.reshape(b, c * scale * scale, h, w)


class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv whose offsets and masks come from a second input.

    Unlike the official DCNv2Pack, which predicts offsets and masks from the
    same features it convolves, this variant predicts them (through the
    inherited ``conv_offset``) from a separate feature map ``feat``.

    Ref:
        Delving Deep into Deformable Alignment in Video Super-Resolution.
    """

    def forward(self, x, feat):
        """Apply modulated deformable convolution to ``x``.

        Args:
            x (Tensor): Features to be convolved.
            feat (Tensor): Features from which offsets and masks are predicted.

        Returns:
            Tensor: Output of the deformable convolution.
        """
        # conv_offset output is split into three equal channel groups: two
        # offset halves followed by the mask logits.
        offset_a, offset_b, mask_logits = torch.chunk(self.conv_offset(feat), 3, dim=1)
        offset = torch.cat((offset_a, offset_b), dim=1)
        mask = torch.sigmoid(mask_logits)

        # A large mean offset magnitude is only reported; computation continues.
        offset_absmean = torch.mean(torch.abs(offset))
        if offset_absmean > 50:
            get_root_logger().warning(f'Offset abs mean is {offset_absmean}, larger than 50.')

        if LooseVersion(torchvision.__version__) < LooseVersion('0.9.0'):
            # Older torchvision: use the basicsr.ops.dcn implementation.
            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                         self.dilation, self.groups, self.deformable_groups)
        return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                             self.dilation, mask)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        low = norm_cdf((a - mean) / std)
        up = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [low, up], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * low - 1, 2 * up - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill ``tensor`` in place with values from a truncated normal distribution.

    Adapted from: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py

    The values are effectively drawn from the normal distribution
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. Sampling uses the inverse CDF rather than rejection. The
    method works best when :math:`a \leq \text{mean} \leq b`; a warning is
    emitted when ``mean`` is more than two ``std`` outside ``[a, b]``.

    Args:
        tensor: an n-dimensional `torch.Tensor`, filled in place
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor`` object.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)


# From PyTorch
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


# Ready-made converters: to_Ntuple(v) repeats a non-iterable v N times and
# returns iterables unchanged; to_ntuple(n) builds such a converter for any n.
to_1tuple, to_2tuple, to_3tuple, to_4tuple = (_ntuple(k) for k in (1, 2, 3, 4))
to_ntuple = _ntuple

================================================
FILE: basicsr/archs/codeformer_arch.py
================================================
import math
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional, List

from basicsr.archs.vqgan_arch import *
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY

def calc_mean_std(feat, eps=1e-5):
    """Per-sample, per-channel spatial mean and std for adaptive_instance_normalization.

    Args:
        feat (Tensor): 4D tensor of shape (b, c, h, w).
        eps (float): Added to the (unbiased) variance before the square root
            to avoid division by zero downstream. Default: 1e-5.

    Returns:
        tuple[Tensor, Tensor]: ``(mean, std)``, each of shape (b, c, 1, 1).
    """
    dims = feat.size()
    assert len(dims) == 4, 'The input feature should be 4D tensor.'
    n, ch = dims[0], dims[1]
    flat = feat.view(n, ch, -1)
    std = (flat.var(dim=2) + eps).sqrt().view(n, ch, 1, 1)
    mean = flat.mean(dim=2).view(n, ch, 1, 1)
    return mean, std


def adaptive_instance_normalization(content_feat, style_feat):
    """Adaptive instance normalization (AdaIN).

    Standardizes each channel of ``content_feat`` with its own spatial mean and
    std, then re-scales and re-shifts it with the per-channel std and mean of
    ``style_feat``. Used to give the reference features color and illumination
    similar to the degraded features.

    Args:
        content_feat (Tensor): The reference feature, 4D (b, c, h, w).
        style_feat (Tensor): The degraded feature, 4D with matching (b, c).

    Returns:
        Tensor: Tensor with the same shape as ``content_feat``.
    """
    shape = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    standardized = (content_feat - content_mean.expand(shape)) / content_std.expand(shape)
    return standardized * style_std.expand(shape) + style_mean.expand(shape)


class PositionEmbeddingSine(nn.Module):
    """2D sine/cosine position embedding for image feature maps.

    Generalizes the sinusoidal encoding of "Attention Is All You Need" to
    images. The first ``num_pos_feats`` output channels encode the row (y)
    position and the next ``num_pos_feats`` the column (x) position, with sine
    on even feature indices and cosine on odd ones.

    Args:
        num_pos_feats (int): Channels per axis; the output has twice as many.
        temperature (int): Base of the frequency progression.
        normalize (bool): If True, divide the running positions by the last one
            along each axis and multiply by ``scale``.
        scale (float | None): Normalization range; only allowed together with
            ``normalize=True``. Defaults to ``2 * pi``.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if normalize is False and scale is not None:
            raise ValueError("normalize should be True if scale is passed")
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x, mask=None):
        """Compute the embedding for a (b, c, h, w) input.

        Args:
            x (Tensor): Only its batch size, spatial size and device are read
                (and only its device when ``mask`` is given).
            mask (Tensor | None): Optional bool tensor of shape (b, h, w);
                True entries are excluded from the running position count.

        Returns:
            Tensor: Embedding of shape (b, 2 * num_pos_feats, h, w).
        """
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        valid = ~mask
        # Running count of unmasked positions along rows (dim 1) and columns (dim 2).
        y_pos = valid.cumsum(1, dtype=torch.float32)
        x_pos = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_pos = y_pos / (y_pos[:, -1:, :] + eps) * self.scale
            x_pos = x_pos / (x_pos[:, :, -1:] + eps) * self.scale

        # Feature pairs (2i, 2i + 1) share the divisor temperature ** (2i / num_pos_feats).
        freqs = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        freqs = self.temperature ** (2 * (freqs // 2) / self.num_pos_feats)

        embeds = []
        for coord in (y_pos, x_pos):
            phase = coord[:, :, :, None] / freqs
            # Interleave sin(even indices) / cos(odd indices) along the feature axis.
            embeds.append(torch.stack((phase[:, :, :, 0::2].sin(), phase[:, :, :, 1::2].cos()), dim=4).flatten(3))
        return torch.cat(embeds, dim=3).permute(0, 3, 1, 2)

def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")


class TransformerSALayer(nn.Module):
    """Pre-norm Transformer layer: self-attention followed by an MLP.

    Each sub-layer runs on a LayerNorm-ed copy of its input, and its output is
    added back residually. Positional embeddings are added to the queries and
    keys only, never to the values.

    Args:
        embed_dim (int): Token feature size.
        nhead (int): Number of attention heads. Default: 8.
        dim_mlp (int): Hidden size of the feed-forward MLP. Default: 2048.
        dropout (float): Dropout probability used throughout. Default: 0.0.
        activation (str): MLP activation name, resolved with
            ``_get_activation_fn``. Default: "gelu".
    """

    def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
        super().__init__()
        # Creation order and attribute names fix the parameter-init order and
        # the state_dict keys; keep both unchanged.
        self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
        # Feed-forward MLP.
        self.linear1 = nn.Linear(embed_dim, dim_mlp)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_mlp, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Apply the layer to a token sequence.

        Args:
            tgt (Tensor): Tokens. ``self_attn`` is built without ``batch_first``,
                so this is (seq_len, batch, embed_dim).
            tgt_mask (Tensor | None): ``attn_mask`` for the self-attention.
            tgt_key_padding_mask (Tensor | None): ``key_padding_mask`` for the
                self-attention.
            query_pos (Tensor | None): Positional embedding added to queries
                and keys.

        Returns:
            Tensor: Same shape as ``tgt``.
        """
        # Self-attention sub-layer.
        normed = self.norm1(tgt)
        qk = self.with_pos_embed(normed, query_pos)
        attn_out, _ = self.self_attn(qk, qk, value=normed, attn_mask=tgt_mask,
                                     key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout1(attn_out)

        # Feed-forward sub-layer.
        normed = self.norm2(tgt)
        mlp_out = self.linear2(self.dropout(self.activation(self.linear1(normed))))
        return tgt + self.dropout2(mlp_out)

class Fuse_sft_block(nn.Module):
    """SFT-style fusion of encoder features into decoder features.

    ``[enc_feat, dec_feat]`` (concatenated on channels) is encoded by a
    ``ResBlock``. Two conv heads then predict a per-pixel ``scale`` and
    ``shift``, and the output is ``dec_feat + w * (dec_feat * scale + shift)``.

    NOTE(review): the scale/shift heads take ``in_ch`` input channels but are
    fed the ResBlock output, and ``dec_feat * scale`` needs matching channels,
    so presumably ``in_ch == out_ch`` is required (CodeFormer passes equal
    values) — confirm against ResBlock.

    Args:
        in_ch (int): Channels of each of ``enc_feat`` and ``dec_feat``.
        out_ch (int): Channels produced by the ResBlock and by the heads.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.encode_enc = ResBlock(2 * in_ch, out_ch)

        def _head():
            # conv3x3 -> LeakyReLU(0.2) -> conv3x3
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.LeakyReLU(0.2, True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))

        # Built in this order (scale, then shift) to keep init/state_dict order.
        self.scale = _head()
        self.shift = _head()

    def forward(self, enc_feat, dec_feat, w=1):
        cond = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
        modulated = dec_feat * self.scale(cond) + self.shift(cond)
        return dec_feat + w * modulated


@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    """Codebook-lookup Transformer on top of a VQ autoencoder.

    Pipeline (see ``forward``): the inherited encoder produces ``lq_feat``.
    A stack of ``TransformerSALayer`` layers then predicts, for every spatial
    token, logits over the ``codebook_size`` codebook entries. The top-1
    entries are fetched from ``self.quantize`` and decoded by the inherited
    generator. When ``w > 0``, encoder features at the sizes listed in
    ``connect_list`` are fused into the generator via ``Fuse_sft_block``.

    Args:
        dim_embd (int): Transformer embedding size. Default: 512.
        n_head (int): Attention heads per Transformer layer. Default: 8.
        n_layers (int): Number of Transformer layers. Default: 9.
        codebook_size (int): Number of codebook entries, i.e. logits per
            token. Default: 1024.
        latent_size (int): Rows of the learned position embedding; it should
            equal h*w of ``lq_feat`` (presumably 16*16). Default: 256.
        connect_list (list[str]): Spatial sizes (as strings) at which encoder
            features are fused into the generator.
            NOTE(review): mutable default, stored on self without copying.
        fix_modules (list[str] | None): Attribute names whose parameters get
            ``requires_grad = False``.
        vqgan_path (str | None): Optional checkpoint. Its ``'params_ema'``
            entry is loaded into the VQ autoencoder part.
    """

    def __init__(self, dim_embd=512, n_head=8, n_layers=9, 
                codebook_size=1024, latent_size=256,
                connect_list=['32', '64', '128', '256'],
                fix_modules=['quantize','generator'], vqgan_path=None):
        # Build the VQ autoencoder with fixed hyper-parameters; see
        # VQAutoEncoder for the meaning of each positional argument.
        super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)

        # Load the VQGAN weights BEFORE any CodeFormer-specific module is
        # registered. load_state_dict is strict by default, so the checkpoint
        # must match exactly the modules that exist at this point.
        if vqgan_path is not None:
            self.load_state_dict(
                torch.load(vqgan_path, map_location='cpu')['params_ema'])

        # Freeze the listed submodules (looked up by attribute name).
        if fix_modules is not None:
            for module in fix_modules:
                for param in getattr(self, module).parameters():
                    param.requires_grad = False

        self.connect_list = connect_list
        self.n_layers = n_layers
        self.dim_embd = dim_embd
        self.dim_mlp = dim_embd*2

        # Learned position embedding: one dim_embd vector per token.
        self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
        # Projects the 256-channel lq_feat tokens to the Transformer width.
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) 
                                    for _ in range(self.n_layers)])

        # logits_predict head: per-token scores over the codebook entries
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))
        
        # Channel count assumed for the feature maps at each spatial size;
        # used to size the fusion blocks below.
        self.channels = {
            '16': 512,
            '32': 256,
            '64': 256,
            '128': 128,
            '256': 128,
            '512': 64,
        }

        # Values are indices into self.encoder.blocks / self.generator.blocks.
        # after second residual block for > 16, before attn layer for ==16
        self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
        # after first residual block for > 16, before attn layer for ==16
        self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}

        # fuse_convs_dict: one Fuse_sft_block per connected size
        self.fuse_convs_dict = nn.ModuleDict()
        for f_size in self.connect_list:
            in_ch = self.channels[f_size]
            self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights, zero Linear
        biases, and identity LayerNorm.

        NOTE(review): not applied anywhere in this class (no ``self.apply``
        call here); confirm whether any caller uses it.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
        """Encode ``x``, predict codebook indices and (optionally) decode.

        Args:
            x (Tensor): Input batch (b, c, h, w). The latent is assumed to be
                16x16 (see the hard-coded shape below) — TODO confirm the
                expected input resolution.
            w (float): Fidelity weight. Encoder features are fused into the
                generator only when ``w > 0``, scaled by ``w``. Default: 0.
            detach_16 (bool): Detach the looked-up codebook features from the
                autograd graph. Default: True.
            code_only (bool): Return right after the logits head. Default: False.
            adain (bool): Re-normalize the codebook features with the
                per-channel statistics of ``lq_feat`` (AdaIN). Default: False.

        Returns:
            tuple: ``(logits, lq_feat)`` if ``code_only``, else
            ``(out, logits, lq_feat)``. ``logits`` has shape
            (b, h*w, codebook_size) and is not softmax-ed.
        """
        # ################### Encoder #####################
        enc_feat_dict = {}
        # Encoder block indices whose outputs are kept for fusion.
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x) 
            if i in out_list:
                # Keyed by spatial width as a string, e.g. '64'.
                enc_feat_dict[str(x.shape[-1])] = x.clone()

        lq_feat = x
        # ################# Transformer ###################
        # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
        # (latent_size, dim) -> (latent_size, b, dim), matching the (HW)BC tokens.
        pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
        # BCHW -> BC(HW) -> (HW)BC
        feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
        query_emb = feat_emb
        # Transformer encoder
        for layer in self.ft_layers:
            query_emb = layer(query_emb, query_pos=pos_emb)

        # output logits
        logits = self.idx_pred_layer(query_emb) # (hw)bn
        logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n

        if code_only: # for training stage II
            # logits doesn't need softmax before cross_entropy loss
            return logits, lq_feat

        # ################# Quantization ###################
        # if self.training:
        #     quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
        #     # b(hw)c -> bc(hw) -> bchw
        #     quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
        # ------------
        # Top-1 code index per token (softmax does not change the argmax).
        soft_one_hot = F.softmax(logits, dim=2)
        _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
        # NOTE(review): latent shape is hard-coded to (b, 16, 16, 256) here.
        quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
        # preserve gradients
        # quant_feat = lq_feat + (quant_feat - lq_feat).detach()

        if detach_16:
            quant_feat = quant_feat.detach() # for training stage III
        if adain:
            quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)

        # ################## Generator ####################
        x = quant_feat
        # Generator block indices after which encoder features are fused.
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x) 
            if i in fuse_list: # fuse after i-th block
                f_size = str(x.shape[-1])
                if w>0:
                    # Encoder features are detached: no gradient reaches the
                    # encoder through this fusion path.
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat

================================================
FILE: basicsr/archs/rrdbnet_arch.py
================================================
import torch
from torch import nn as nn
from torch.nn import functional as F

from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import default_init_weights, make_layer, pixel_unshuffle


class ResidualDenseBlock(nn.Module):
    """Residual Dense Block, the building unit of ESRGAN's RRDB.

    Five 3x3 convolutions. Conv k receives the block input concatenated with
    the outputs of all previous convs, and the first four are followed by
    LeakyReLU(0.2). The last conv's output is scaled by 0.2 and added to the
    block input.

    Args:
        num_feat (int): Channel number of the input/output features. Default: 64.
        num_grow_ch (int): Channels produced by each of the first four convs.
            Default: 32.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        # Each of conv1..conv4 adds num_grow_ch channels to the running
        # concatenation; conv5 maps back to num_feat.
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Small-scale init for the residual branch.
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        feats = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        fused = self.conv5(torch.cat(feats, 1))
        # Empirically, scaling the residual by 0.2 gives better performance.
        return fused * 0.2 + x


class RRDB(nn.Module):
    """Residual in Residual Dense Block (RRDB), used in ESRGAN's RRDBNet.

    Three ResidualDenseBlocks applied in sequence. Their output is scaled by
    0.2 and added back to the block input.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = x
        for rdb in (self.rdb1, self.rdb2, self.rdb3):
            out = rdb(out)
        # Empirically, scaling the residual by 0.2 gives better performance.
        return x + 0.2 * out


@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
    """ESRGAN generator built from Residual-in-Residual Dense Blocks.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    The trunk always upsamples by 4 (two nearest-neighbour x2 stages, each
    followed by a conv). For an overall scale of x2 or x1, the input is first
    pixel-unshuffled by 2 or 4, respectively. This shrinks the spatial size and
    multiplies the channel count by 4 or 16 before the main ESRGAN body.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Overall upsampling factor. 2 and 1 trigger the
            pixel-unshuffle; any other value feeds the input directly, giving
            a 4x output. Default: 4.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_block (int): Number of RRDB blocks in the trunk. Default: 23.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        # pixel_unshuffle by r multiplies the channel count by r**2.
        if scale in (1, 2):
            num_in_ch *= 16 if scale == 1 else 4
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale in (1, 2):
            x = pixel_unshuffle(x, scale=4 if self.scale == 1 else 2)
        feat = self.conv_first(x)
        feat = feat + self.conv_body(self.body(feat))
        # Two nearest-neighbour x2 stages: x4 overall.
        for conv_up in (self.conv_up1, self.conv_up2):
            feat = self.lrelu(conv_up(F.interpolate(feat, scale_factor=2, mode='nearest')))
        return self.conv_last(self.lrelu(self.conv_hr(feat)))

================================================
FILE: basicsr/archs/vgg_arch.py
================================================
import os
import torch
from collections import OrderedDict
from torch import nn as nn
from torchvision.models import vgg as vgg

from basicsr.utils.registry import ARCH_REGISTRY

# Optional local copy of pretrained VGG19 weights. VGGFeatureExtractor loads it
# when the file exists (whatever vgg_type is requested); otherwise it asks
# torchvision for pretrained=True weights.
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
# Layer names for the `features` sequence of each VGG variant, in module order.
# VGGFeatureExtractor zips these names with `vgg_net.features`; '*_bn' variants
# are derived by insert_bn, which adds a 'bnX_Y' after every 'convX_Y'.
NAMES = {
    'vgg11': [
        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
        'pool5'
    ],
    'vgg13': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
    ],
    'vgg16': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
        'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
        'pool5'
    ],
    'vgg19': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
    ]
}


def insert_bn(names):
    """Insert a batch-norm layer name after every conv layer name.

    Every ``'convX_Y'`` is followed by ``'bnX_Y'``; all other names are kept
    unchanged and in order.

    Args:
        names (list): The list of layer names.

    Returns:
        list: A new list of layer names including the bn layers.
    """
    out = []
    for layer in names:
        out.extend([layer, 'bn' + layer.replace('conv', '')] if 'conv' in layer else [layer])
    return out


@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
    """VGG network truncated for feature extraction.

    Only the VGG ``features`` layers up to the deepest requested layer are
    kept. Input normalization is optional, and pooling layers can be removed
    or given a custom stride. Note that the local pretrained file
    (``VGG_PRETRAIN_PATH``) must fit the chosen vgg type.

    Args:
        layer_name_list (list[str]): Names of the layers whose outputs are
            returned by ``forward``, e.g. ['relu1_1', 'relu2_1', 'relu3_1'].
        vgg_type (str): torchvision VGG constructor name. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input with per-channel
            mean/std; the input must then be in range [0, 1]. Default: True.
        range_norm (bool): If True, first map inputs from [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If True, the VGG parameters are trainable and
            the network is put in train mode; otherwise they are frozen and it
            is put in eval mode. Default: False.
        remove_pooling (bool): If True, drop the max-pooling layers.
            Default: False.
        pooling_stride (int): Stride of the max-pooling layers. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        base_names = NAMES[vgg_type.replace('_bn', '')]
        self.names = insert_bn(base_names) if 'bn' in vgg_type else base_names

        # Keep only the layers up to the deepest requested one, so unused
        # parameters are not created.
        max_idx = max([0] + [self.names.index(v) for v in layer_name_list])

        use_local_weights = os.path.exists(VGG_PRETRAIN_PATH)
        vgg_net = getattr(vgg, vgg_type)(pretrained=not use_local_weights)
        if use_local_weights:
            vgg_net.load_state_dict(torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage))

        modified_net = OrderedDict()
        for k, v in zip(self.names, vgg_net.features[:max_idx + 1]):
            if 'pool' not in k:
                modified_net[k] = v
            elif not remove_pooling:
                # Rebuild the pooling layer so its stride can be customized.
                modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
        self.vgg_net = nn.Sequential(modified_net)

        if requires_grad:
            self.vgg_net.train()
        else:
            self.vgg_net.eval()
        for param in self.parameters():
            param.requires_grad = bool(requires_grad)

        if self.use_input_norm:
            # Per-channel mean / std for images in range [0, 1].
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Run the truncated VGG and collect the requested features.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict[str, Tensor]: Cloned outputs of the layers named in
            ``layer_name_list``, keyed by layer name.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        features = {}
        for name, layer in self.vgg_net._modules.items():
            x = layer(x)
            if name in self.layer_name_list:
                features[name] = x.clone()
        return features


================================================
FILE: basicsr/archs/vqgan_arch.py
================================================
'''
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py

'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY

def normalize(in_channels):
    """Build the normalization layer used across the VQGAN: GroupNorm with 32 groups,
    eps=1e-6 and a learnable affine transform over ``in_channels`` channels."""
    return nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
    

@torch.jit.script
def swish(x):
    # swish / SiLU activation: x * sigmoid(x)
    gate = torch.sigmoid(x)
    return x * gate


#  Define VQVAE classes
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a learnable codebook.

    Args:
        codebook_size (int): Number of codebook entries.
        emb_dim (int): Dimension of each entry; must equal the channel dim of
            the features passed to ``forward``.
        beta (float): Weight applied inside the loss returned by ``forward``.
    """

    def __init__(self, codebook_size, emb_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.codebook_size = codebook_size  # number of embeddings
        self.emb_dim = emb_dim  # dimension of embedding
        # beta weights the codebook term mean((z_q - sg[z])^2) in forward().
        # NOTE: the VQ-VAE paper puts beta on the commitment term instead; the
        # weighting is kept as-is so training behaviour stays unchanged.
        self.beta = beta
        self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)

    def forward(self, z):
        """Quantize ``z`` to its nearest codebook entries.

        Args:
            z (Tensor): Features of shape (b, emb_dim, h, w).

        Returns:
            tuple: ``(z_q, loss, stats)``.
            - ``z_q`` has the same shape as ``z``; gradients flow straight through to ``z``.
            - ``loss`` is a scalar.
            - ``stats`` holds:
              - ``perplexity``;
              - the one-hot ``min_encodings`` of shape (b*h*w, codebook_size);
              - ``min_encoding_indices`` of shape (b*h*w, 1);
              - ``mean_distance``, the mean over the full distance matrix.
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.emb_dim)

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
            2 * torch.matmul(z_flattened, self.embedding.weight.t())

        mean_distance = torch.mean(d)
        # find closest encodings
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)

        min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
        # first term pulls z towards sg[z_q]; second (weighted by beta) pulls the codebook towards sg[z]
        loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()

        # perplexity of the code usage distribution
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q, loss, {
            "perplexity": perplexity,
            "min_encodings": min_encodings,
            "min_encoding_indices": min_encoding_indices,
            "mean_distance": mean_distance
            }

    def get_codebook_feat(self, indices, shape):
        """Look up codebook vectors for integer ``indices``.

        Args:
            indices (Tensor): Integer code indices of any shape (flattened internally).
            shape (tuple | None): Target (batch, height, width, channel). When
                given, the result is reshaped to it and permuted to
                (batch, channel, height, width); when None the flat
                (num_indices, emb_dim) matrix is returned.

        Returns:
            Tensor: Codebook features with the codebook's dtype.
        """
        # reshape (not view) so non-contiguous index tensors are accepted
        indices = indices.reshape(-1, 1)
        weight = self.embedding.weight
        # Build the one-hot matrix directly in the codebook dtype so matmul also
        # works for float64 / float16 codebooks (a hard-coded .float() did not).
        min_encodings = torch.zeros(indices.shape[0], self.codebook_size, device=indices.device, dtype=weight.dtype)
        min_encodings.scatter_(1, indices, 1)
        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, weight)

        if shape is not None:  # reshape back to match original input shape
            z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()

        return z_q


class GumbelQuantizer(nn.Module):
    """Quantizer that selects codebook entries with Gumbel-softmax sampling.

    Args:
        codebook_size (int): Number of codebook entries.
        emb_dim (int): Dimension of each codebook entry.
        num_hiddens (int): Channel count of the feature map fed to ``proj``.
        straight_through (bool): Use hard one-hot samples during training. Default: False.
        kl_weight (float): Weight on the KL term returned as the loss. Default: 5e-4.
        temp_init (float): Gumbel-softmax temperature. Default: 1.0.
    """

    def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
        super().__init__()
        self.codebook_size = codebook_size
        self.emb_dim = emb_dim
        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight
        # 1x1 conv: one logit per codebook entry at every spatial location
        self.proj = nn.Conv2d(num_hiddens, codebook_size, 1)
        self.embed = nn.Embedding(codebook_size, emb_dim)

    def forward(self, z):
        """Return ``(z_q, kl_loss, {"min_encoding_indices": ...})`` for input ``z``."""
        # outside training, always draw hard one-hot samples
        use_hard = True if not self.training else self.straight_through

        logits = self.proj(z)
        one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=use_hard)
        z_q = torch.einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)

        # KL(q(y|x) || uniform prior over the codebook), averaged over positions
        probs = F.softmax(logits, dim=1)
        kl = torch.sum(probs * torch.log(probs * self.codebook_size + 1e-10), dim=1).mean()
        diff = self.kl_weight * kl

        return z_q, diff, {
            "min_encoding_indices": one_hot.argmax(dim=1)
        }


class Downsample(nn.Module):
    """Halve the spatial size with an unpadded stride-2 3x3 convolution.

    The input is first zero-padded by one pixel on the right and bottom only,
    so the padding is asymmetric.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)


class Upsample(nn.Module):
    """Double the spatial size: nearest-neighbour 2x interpolation, then a 3x3 conv."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        enlarged = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged)


class ResBlock(nn.Module):
    """Pre-activation residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a skip path.

    Args:
        in_channels (int): Input channels.
        out_channels (int | None): Output channels. Defaults to ``in_channels`` when None.
            When the two differ, a 1x1 ``conv_out`` projects the skip path.
            ``normalize`` uses 32 groups, so channel counts presumably need to
            be divisible by 32.
    """

    def __init__(self, in_channels, out_channels=None):
        super(ResBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # Build layers from the resolved self.out_channels: the raw argument may
        # be None, which previously crashed nn.Conv2d / GroupNorm.
        self.norm1 = normalize(self.in_channels)
        self.conv1 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = normalize(self.out_channels)
        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.conv_out = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x_in):
        x = x_in
        x = self.norm1(x)
        x = swish(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = swish(x)
        x = self.conv2(x)
        if self.in_channels != self.out_channels:
            # match the skip path's channel count to the residual branch
            x_in = self.conv_out(x_in)

        return x + x_in


class AttnBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    Every spatial position of the normalized feature map attends to every
    other position. The query/key/value projections and the output
    projection are 1x1 convolutions.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def _pointwise():
            # channel-preserving 1x1 convolution
            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.norm = normalize(in_channels)
        self.q = _pointwise()
        self.k = _pointwise()
        self.v = _pointwise()
        self.proj_out = _pointwise()

    def forward(self, x):
        normed = self.norm(x)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)

        batch, channels, height, width = query.shape
        n_pos = height * width

        # (b, hw, c) @ (b, c, hw) -> (b, hw_q, hw_k) scaled dot-product scores
        scores = torch.bmm(query.reshape(batch, channels, n_pos).permute(0, 2, 1),
                           key.reshape(batch, channels, n_pos))
        scores = scores * (int(channels)**(-0.5))
        attn = F.softmax(scores, dim=2)

        # (b, c, hw_k) @ (b, hw_k, hw_q) -> (b, c, hw_q)
        out = torch.bmm(value.reshape(batch, channels, n_pos), attn.permute(0, 2, 1))
        out = out.reshape(batch, channels, height, width)

        return x + self.proj_out(out)


class Encoder(nn.Module):
    """Convolutional VQGAN encoder mapping an image to a latent map with ``emb_dim`` channels.

    ``self.blocks`` is built in the order below. The ModuleList indices become
    the ``blocks.<i>`` state_dict keys, so this order must not change.
      1. A 3x3 conv ``in_channels -> nf``.
      2. For each level i:
         - ``num_res_blocks`` ResBlocks with ``nf * ch_mult[i]`` output
           channels, each followed by an AttnBlock when the tracked resolution
           is in ``attn_resolutions``;
         - a Downsample on every level except the last (the tracked
           resolution is halved).
      3. ResBlock -> AttnBlock -> ResBlock at the final resolution.
      4. A GroupNorm followed directly by a 3x3 conv to ``emb_dim`` channels.

    Args:
        in_channels (int): Channels of the input image.
        nf (int): Base channel count.
        emb_dim (int): Channels of the output latent.
        ch_mult (sequence[int]): Channel multiplier per resolution level.
        num_res_blocks (int): ResBlocks per level.
        resolution (int): Input spatial size. It is only used to track the
            current resolution when deciding where to insert attention.
        attn_resolutions (container[int]): Resolutions that receive AttnBlocks.
            Membership is tested with ``in`` against an int, so entries of
            another type (e.g. strings) never match.
    """
    def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.attn_resolutions = attn_resolutions

        curr_res = self.resolution
        # level i takes nf * in_ch_mult[i] input channels, i.e. the previous level's width (nf for level 0)
        in_ch_mult = (1,)+tuple(ch_mult)

        blocks = []
        # initial convolution
        blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))

        # residual and downsampling blocks, with attention at the resolutions listed in attn_resolutions
        for i in range(self.num_resolutions):
            block_in_ch = nf * in_ch_mult[i]
            block_out_ch = nf * ch_mult[i]
            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch
                if curr_res in attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != self.num_resolutions - 1:
                blocks.append(Downsample(block_in_ch))
                curr_res = curr_res // 2

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        # normalise and convert to latent size
        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        # apply every block sequentially
        for block in self.blocks:
            x = block(x)

        return x


class Generator(nn.Module):
    """VQGAN decoder mapping an ``emb_dim``-channel latent map back to a 3-channel image.

    ``self.blocks`` is built in the order below. The ModuleList indices become
    the ``blocks.<i>`` state_dict keys, so this order must not change.
      1. A 3x3 conv ``emb_dim -> nf * ch_mult[-1]``.
      2. ResBlock -> AttnBlock -> ResBlock.
      3. For each level i from the last down to 0:
         - ``res_blocks`` ResBlocks with ``nf * ch_mult[i]`` output channels,
           each followed by an AttnBlock when the tracked resolution is in
           ``attn_resolutions``;
         - an Upsample on every level except level 0 (the tracked resolution
           is doubled).
      4. A GroupNorm followed directly by a 3x3 conv to 3 channels.

    Args:
        nf (int): Base channel count.
        emb_dim (int): Channels of the input latent.
        ch_mult (sequence[int]): Channel multiplier per resolution level.
        res_blocks (int): ResBlocks per level.
        img_size (int): Output image size. The tracked resolution starts at
            ``img_size // 2 ** (len(ch_mult) - 1)``.
        attn_resolutions (container[int]): Resolutions that receive AttnBlocks
            (tested with ``in`` against an int).
    """
    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
        super().__init__()
        self.nf = nf 
        self.ch_mult = ch_mult 
        self.num_resolutions = len(self.ch_mult)
        self.num_res_blocks = res_blocks
        self.resolution = img_size 
        self.attn_resolutions = attn_resolutions
        self.in_channels = emb_dim
        self.out_channels = 3
        block_in_ch = self.nf * self.ch_mult[-1]
        # latent resolution: the image size divided by 2 for each of the (num_resolutions - 1) upsamplings
        curr_res = self.resolution // 2 ** (self.num_resolutions-1)

        blocks = []
        # initial conv
        blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))

        # non-local attention block
        blocks.append(ResBlock(block_in_ch, block_in_ch))
        blocks.append(AttnBlock(block_in_ch))
        blocks.append(ResBlock(block_in_ch, block_in_ch))

        # walk the levels from coarsest to finest
        for i in reversed(range(self.num_resolutions)):
            block_out_ch = self.nf * self.ch_mult[i]

            for _ in range(self.num_res_blocks):
                blocks.append(ResBlock(block_in_ch, block_out_ch))
                block_in_ch = block_out_ch

                if curr_res in self.attn_resolutions:
                    blocks.append(AttnBlock(block_in_ch))

            if i != 0:
                blocks.append(Upsample(block_in_ch))
                curr_res = curr_res * 2

        # normalise and project to the RGB output channels
        blocks.append(normalize(block_in_ch))
        blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))

        self.blocks = nn.ModuleList(blocks)
   

    def forward(self, x):
        # apply every block sequentially
        for block in self.blocks:
            x = block(x)

        return x

  
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    """VQGAN autoencoder: Encoder -> quantizer -> Generator.

    Args:
        img_size (int): Image resolution (used to place attention blocks).
        nf (int): Base channel count.
        ch_mult (sequence[int]): Channel multiplier per resolution level.
        quantizer (str): 'nearest' (VectorQuantizer) or 'gumbel'
            (GumbelQuantizer). Default: 'nearest'.
        res_blocks (int): ResBlocks per resolution level. Default: 2.
        attn_resolutions (list[int]): Resolutions that get attention blocks.
            Default: [16].
        codebook_size (int): Number of codebook entries. Default: 1024.
        emb_dim (int): Latent / codebook dimension. Default: 256.
        beta (float): Loss weight passed to VectorQuantizer. Default: 0.25.
        gumbel_straight_through (bool): Hard samples while training
            (GumbelQuantizer only). Default: False.
        gumbel_kl_weight (float): KL weight (GumbelQuantizer only). Default: 1e-8.
        model_path (str | None): Optional checkpoint. Its 'params_ema' entry is
            loaded if present, otherwise its 'params' entry.

    Raises:
        ValueError: If ``quantizer`` is unknown, or the checkpoint contains
            neither 'params_ema' nor 'params'.
    """

    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
                beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
        super().__init__()
        logger = get_root_logger()
        self.in_channels = 3
        self.nf = nf
        self.n_blocks = res_blocks
        self.codebook_size = codebook_size
        self.embed_dim = emb_dim
        self.ch_mult = ch_mult
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.quantizer_type = quantizer
        self.encoder = Encoder(
            self.in_channels,
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )
        if self.quantizer_type == "nearest":
            self.beta = beta #0.25
            self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
        elif self.quantizer_type == "gumbel":
            self.gumbel_num_hiddens = emb_dim
            self.straight_through = gumbel_straight_through
            self.kl_weight = gumbel_kl_weight
            self.quantize = GumbelQuantizer(
                self.codebook_size,
                self.embed_dim,
                self.gumbel_num_hiddens,
                self.straight_through,
                self.kl_weight
            )
        else:
            # Fail early instead of leaving self.quantize undefined until forward().
            raise ValueError(f'Unknown quantizer type: {self.quantizer_type}')
        self.generator = Generator(
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )

        if model_path is not None:
            # Read the checkpoint once and pick the weight set from it.
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints (newer PyTorch versions offer weights_only=True).
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_ema' in chkpt:
                self.load_state_dict(chkpt['params_ema'])
                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
                logger.info(f'vqgan is loaded from: {model_path} [params]')
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        """Encode, quantize and decode ``x``.

        Returns:
            tuple: (reconstruction, codebook_loss, quant_stats). The second and
            third items are what ``self.quantize`` returns.
        """
        x = self.encoder(x)
        quant, codebook_loss, quant_stats = self.quantize(x)
        x = self.generator(quant)
        return x, codebook_loss, quant_stats



# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
    """Patch-based discriminator producing a 1-channel prediction map.

    Layout:
      - ``n_layers`` stride-2 4x4 convs: the first without BatchNorm, the
        following ones as conv-BatchNorm-LeakyReLU with ``ndf * min(2**n, 8)``
        channels;
      - one stride-1 conv-BatchNorm-LeakyReLU stage;
      - a final stride-1 conv to 1 channel.

    Args:
        nc (int): Input channels. Default: 3.
        ndf (int): Base number of filters. Default: 64.
        n_layers (int): Number of stride-2 convolutions. Default: 4.
        model_path (str | None): Optional checkpoint. Its 'params_d' entry is
            loaded if present, otherwise its 'params' entry.

    Raises:
        ValueError: If the checkpoint contains neither 'params_d' nor 'params'.
    """

    def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
        super().__init__()

        layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
        ndf_mult = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            ndf_mult_prev = ndf_mult
            ndf_mult = min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(ndf * ndf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        ndf_mult_prev = ndf_mult
        ndf_mult = min(2 ** n_layers, 8)

        layers += [
            nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(ndf * ndf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        layers += [
            nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)]  # output 1 channel prediction map
        self.main = nn.Sequential(*layers)

        if model_path is not None:
            # Read the checkpoint once and pick the weight set from it.
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            chkpt = torch.load(model_path, map_location='cpu')
            if 'params_d' in chkpt:
                self.load_state_dict(chkpt['params_d'])
            elif 'params' in chkpt:
                self.load_state_dict(chkpt['params'])
            else:
                raise ValueError('Wrong params!')

    def forward(self, x):
        return self.main(x)

================================================
FILE: basicsr/data/__init__.py
================================================
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp

from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import DATASET_REGISTRY

__all__ = ['build_dataset', 'build_dataloader']

# Automatically scan and import dataset modules so they are available to the
# registry used by build_dataset (presumably each module registers its classes
# in DATASET_REGISTRY on import — registration itself is not visible here).
# Only files in this package folder whose names end with '_dataset.py' are imported.
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules (kept in a list so the module objects stay referenced)
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]


def build_dataset(dataset_opt):
    """Instantiate a dataset class looked up in DATASET_REGISTRY.

    Args:
        dataset_opt (dict): Configuration for the dataset. It must contain:
            name (str): Dataset name (used in the log message).
            type (str): Registered dataset class name.

    Returns:
        The constructed dataset, built from a deep copy of ``dataset_opt``.
    """
    opt = deepcopy(dataset_opt)
    dataset_cls = DATASET_REGISTRY.get(opt['type'])
    dataset = dataset_cls(opt)
    get_root_logger().info(f'Dataset [{dataset.__class__.__name__}] - {opt["name"]} ' 'is built.')
    return dataset


def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
    """Build a dataloader for the given phase.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train', 'val' or 'test'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
            pin_memory (bool, optional): Passed to the loader. Default: False.
            prefetch_mode (str, optional): 'cpu' selects PrefetchDataLoader;
                anything else (None / 'cuda') yields a plain DataLoader.
            num_prefetch_queue (int, optional): Queue size for 'cpu' prefetch.
                Default: 1.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed for worker_init_fn. Default: None.

    Returns:
        torch.utils.data.DataLoader | PrefetchDataLoader: The dataloader.
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()

    if phase in ['val', 'test']:  # validation: one sample at a time, in the main process
        loader_kwargs = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    elif phase == 'train':
        if dist:
            # each distributed process builds its own per-GPU loader
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']
        else:
            # a single process feeds all GPUs (CPU-only counts as one)
            gpu_factor = num_gpu if num_gpu != 0 else 1
            batch_size = dataset_opt['batch_size_per_gpu'] * gpu_factor
            num_workers = dataset_opt['num_worker_per_gpu'] * gpu_factor
        loader_kwargs = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=sampler is None,  # a custom sampler handles ordering itself
            num_workers=num_workers,
            sampler=sampler,
            drop_last=True)
        if seed is None:
            loader_kwargs['worker_init_fn'] = None
        else:
            loader_kwargs['worker_init_fn'] = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
    else:
        raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")

    loader_kwargs['pin_memory'] = dataset_opt.get('pin_memory', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode != 'cpu':
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**loader_kwargs)

    # CPUPrefetcher
    num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
    get_root_logger().info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
    return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **loader_kwargs)


def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed numpy's and Python's global RNGs for one dataloader worker.

    The seed is ``num_workers * rank + worker_id + seed``, so workers of
    different ranks get distinct seeds.
    """
    worker_seed = num_workers * rank + worker_id + seed
    for seeder in (np.random.seed, random.seed):
        seeder(worker_seed)


================================================
FILE: basicsr/data/data_sampler.py
================================================
import math
import torch
from torch.utils.data.sampler import Sampler


class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler. It can
    enlarge the dataset by ``ratio`` for iteration-based training, which saves
    the time of restarting the dataloader after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        enlarged_len = len(self.dataset) * ratio
        # round up so that every replica gets the same number of samples
        self.num_samples = math.ceil(enlarged_len / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # the shuffle is deterministic: it depends only on the current epoch
        rng = torch.Generator()
        rng.manual_seed(self.epoch)
        perm = torch.randperm(self.total_size, generator=rng).tolist()

        # take this rank's strided share, then wrap enlarged indices back into the real dataset
        real_len = len(self.dataset)
        shard = [idx % real_len for idx in perm[self.rank:self.total_size:self.num_replicas]]
        assert len(shard) == self.num_samples
        return iter(shard)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


================================================
FILE: basicsr/data/data_util.py
================================================
import cv2
import math
import numpy as np
import torch
from os import path as osp
from PIL import Image, ImageDraw
from torch.nn import functional as F

from basicsr.data.transforms import mod_crop
from basicsr.utils import img2tensor, scandir


def read_img_seq(path, require_mod_crop=False, scale=1):
    """Load a sequence of images into a single stacked tensor.

    Args:
        path (list[str] | str): Explicit list of image paths, or a folder whose
            files (from ``scandir(path, full_path=True)``) are read in sorted order.
        require_mod_crop (bool): Apply ``mod_crop`` to each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
    """
    img_paths = path if isinstance(path, list) else sorted(list(scandir(path, full_path=True)))
    # cv2 loads BGR images; divide by 255 to get float32 values in [0, 1]
    frames = list(map(lambda p: cv2.imread(p).astype(np.float32) / 255., img_paths))
    if require_mod_crop:
        frames = list(map(lambda frame: mod_crop(frame, scale), frames))
    return torch.stack(img2tensor(frames, bgr2rgb=True, float32=True), dim=0)


def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Generate the indices of a ``num_frames``-long window centred on ``crt_idx``.

    Positions that fall outside [0, max_frame_num - 1] are mapped back inside
    according to ``padding``.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Max number of the sequence of images (from 1).
        num_frames (int): Reading num_frames frames (must be odd).
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding mode:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    last_idx = max_frame_num - 1  # indices are 0-based
    half = num_frames // 2

    def _before_start(pos):
        # map a position < 0 back into range
        if padding == 'replicate':
            return 0
        if padding == 'reflection':
            return -pos
        if padding == 'reflection_circle':
            return crt_idx + half - pos
        return num_frames + pos  # circle

    def _after_end(pos):
        # map a position > last_idx back into range
        if padding == 'replicate':
            return last_idx
        if padding == 'reflection':
            return last_idx * 2 - pos
        if padding == 'reflection_circle':
            return (crt_idx - half) - (pos - last_idx)
        return pos - num_frames  # circle

    result = []
    for pos in range(crt_idx - half, crt_idx + half + 1):
        if pos < 0:
            result.append(_before_start(pos))
        elif pos > last_idx:
            result.append(_after_end(pos))
        else:
            result.append(pos)
    return result


def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    lq.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records
    1)image name (with extension),
    2)image shape,
    3)compression level, separated by a white space.
    Example: `baboon.png (120,125,3) 1`

    The lmdb key of a line is its text before the first '.', i.e. the image
    name without extension. The same key is used for the corresponding lq and
    gt images.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[dict]: One dict per lmdb key (sorted), mapping
        '<input_key>_path' and '<gt_key>_path' to that key.

    Raises:
        ValueError: If a folder is not an lmdb folder or the key sets differ.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    both_lmdb = input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')
    if not both_lmdb:
        raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
                         f'formats. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')

    def _read_keys(lmdb_folder):
        # one key per meta_info line: the text before the first '.'
        with open(osp.join(lmdb_folder, 'meta_info.txt')) as meta_file:
            return [entry.split('.')[0] for entry in meta_file]

    # ensure that the two meta_info files describe the same keys
    input_lmdb_keys = _read_keys(input_folder)
    gt_lmdb_keys = _read_keys(gt_folder)
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
    return [{f'{input_key}_path': lmdb_key, f'{gt_key}_path': lmdb_key} for lmdb_key in sorted(input_lmdb_keys)]


def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from a meta information file.

    Each line in the meta information file contains the image name and the
    image shape (usually for gt), separated by a white space. Only the text
    before the first space is used.

    Example of a meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per line, mapping '<input_key>_path' and
        '<gt_key>_path' to the joined file paths.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    with open(meta_info_file, 'r') as fin:
        gt_names = [line.split(' ')[0] for line in fin]

    def _pair(gt_name):
        # the input file reuses the gt stem (via filename_tmpl) and extension
        stem, ext = osp.splitext(osp.basename(gt_name))
        input_path = osp.join(input_folder, f'{filename_tmpl.format(stem)}{ext}')
        return {f'{input_key}_path': input_path, f'{gt_key}_path': osp.join(gt_folder, gt_name)}

    return [_pair(name) for name in gt_names]


def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[dict]: One dict per gt file, mapping '<input_key>_path' and
        '<gt_key>_path' to the joined file paths.

    Raises:
        AssertionError: If the folders hold different numbers of files or an
            expected input file is missing.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    # Set gives O(1) membership checks; testing against the list made the loop quadratic.
    input_path_set = set(input_paths)
    paths = []
    for gt_path in gt_paths:
        basename, ext = osp.splitext(osp.basename(gt_path))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        assert input_name in input_path_set, (f'{input_name} is not in ' f'{input_key}_paths.')
        gt_path = osp.join(gt_folder, gt_path)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths


def paths_from_folder(folder):
    """List the entries yielded by ``scandir(folder)``, each joined onto ``folder``.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """
    return [osp.join(folder, rel_path) for rel_path in scandir(folder)]


def paths_from_lmdb(folder):
    """Generate paths (lmdb keys) from an lmdb folder.

    Keys are read from ``meta_info.txt`` inside the folder. For every line the
    text before the first '.' is used, e.g. ``baboon.png (120,125,3) 1`` ->
    ``baboon``.

    Args:
        folder (str): Folder path; must end with '.lmdb'.

    Returns:
        list[str]: Returned path list.

    Raises:
        ValueError: If ``folder`` does not end with '.lmdb'.
    """
    if not folder.endswith('.lmdb'):
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        paths = [line.split('.')[0] for line in fin]
    return paths


def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    The kernel is obtained by Gaussian-smoothing a centred Dirac delta.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The Gaussian kernel of shape (kernel_size, kernel_size).
    """
    # `scipy.ndimage.filters` is a deprecated namespace; the function lives in scipy.ndimage.
    from scipy.ndimage import gaussian_filter
    kernel = np.zeros((kernel_size, kernel_size))
    # set element at the middle to one, a dirac delta
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # gaussian-smooth the dirac, resulting in a gaussian filter
    return gaussian_filter(kernel, sigma)


def duf_downsample(x, kernel_size=13, scale=4):
    """Downsamping with Gaussian kernel used in the DUF official code.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w). A
            4-D input (t, c, h, w) is also accepted; it is treated as b=1 and
            the result is returned without the batch dimension.
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    squeeze_flag = x.ndim == 4
    if squeeze_flag:
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()
    # Filter every channel of every frame as an independent 1-channel image.
    # ``reshape`` (instead of ``view``) also accepts non-contiguous inputs.
    x = x.reshape(-1, 1, h, w)
    pad = kernel_size // 2 + scale * 2
    x = F.pad(x, (pad, pad, pad, pad), 'reflect')

    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
    x = F.conv2d(x, gaussian_filter, stride=scale)
    # hard-coded 2-pixel border crop of the strided result
    x = x[:, :, 2:-2, 2:-2]
    x = x.reshape(b, t, c, x.size(2), x.size(3))
    if squeeze_flag:
        x = x.squeeze(0)
    return x


def brush_stroke_mask(img, color=(255,255,255)):
    """Paint 1-3 random free-form brush strokes onto a copy of ``img``.

    Each stroke is a zig-zag polyline of 8-27 vertices drawn with a random
    integer width in [30, 70), with an ellipse at every vertex for round joints.

    Args:
        img (PIL.Image.Image): Image to paint on. It is copied, so the
            caller's image is not modified.
        color (tuple): Stroke fill colour. Default: (255, 255, 255).

    Returns:
        PIL.Image.Image: A copy of ``img`` with the strokes drawn on it.
    """
    min_num_vertex = 8
    max_num_vertex = 28
    mean_angle = 2*math.pi / 5
    angle_range = 2*math.pi / 12
    # training large mask ratio (training setting)
    min_width = 30
    max_width = 70
    # very large mask ratio (test setting and refine after 200k)
    # min_width = 80
    # max_width = 120

    def generate_mask(H, W, img=None):
        # Typical segment length scales with the image diagonal.
        average_radius = math.sqrt(H*H+W*W) / 8
        if img is None:
            mask = Image.new('RGB', (W, H), 0)
        else:
            # Draw on a copy so the caller's image is left untouched.
            mask = img.copy()

        for _ in range(np.random.randint(1, 4)):
            num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
            angle_min = mean_angle - np.random.uniform(0, angle_range)
            angle_max = mean_angle + np.random.uniform(0, angle_range)
            angles = []
            vertex = []
            for i in range(num_vertex):
                # alternate the turning direction so the stroke zig-zags
                if i % 2 == 0:
                    angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
                else:
                    angles.append(np.random.uniform(angle_min, angle_max))

            # PIL's ``size`` is (width, height); x spans the width, y the height.
            w, h = mask.size
            vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
            for i in range(num_vertex):
                r = np.clip(
                    np.random.normal(loc=average_radius, scale=average_radius//2),
                    0, 2*average_radius)
                new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
                new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
                vertex.append((int(new_x), int(new_y)))

            draw = ImageDraw.Draw(mask)
            width = int(np.random.uniform(min_width, max_width))
            draw.line(vertex, fill=color, width=width)
            # round the joints / ends of the polyline
            for v in vertex:
                draw.ellipse((v[0] - width//2,
                              v[1] - width//2,
                              v[0] + width//2,
                              v[1] + width//2),
                             fill=color)

        return mask

    width, height = img.size
    mask = generate_mask(height, width, img)
    return mask


def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70, times = 10):
    """Generate a random free-form (brush stroke) mask.

    Adapted from DeepFillv2:
    https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py

    Args:
        shape (tuple): ``(height, width, ...)``; only the first two entries are used.
        max_angle (int): Segment angles are ``0.01 + randint(max_angle)``
            radians (mirrored on even strokes). Default: 10.
        max_len (int): Segment length is ``10 + randint(max_len - 20, max_len)``.
            Default: 100.
        max_width (int): Brush width is ``5 + randint(max_width - 30, max_width)``;
            keep it above 25 so the width stays positive. Default: 70.
        times (int): The number of strokes is ``randint(times - 5, times)``.
            Default: 10.

    Returns:
        np.ndarray: float32 mask of shape (height, width), 1.0 on strokes and
        0.0 elsewhere.
    """
    height = shape[0]
    width = shape[1]
    mask = np.zeros((height, width), np.float32)
    times = np.random.randint(times-5, times)
    for i in range(times):
        start_x = int(np.random.randint(width))   # column
        start_y = int(np.random.randint(height))  # row
        for _ in range(1 + np.random.randint(5)):
            angle = 0.01 + np.random.randint(max_angle)
            if i % 2 == 0:
                angle = 2 * 3.1415926 - angle
            length = 10 + np.random.randint(max_len-20, max_len)
            brush_w = 5 + np.random.randint(max_width-30, max_width)
            # int() truncates toward zero, like the original astype(np.int32)
            end_x = int(start_x + length * np.sin(angle))
            end_y = int(start_y + length * np.cos(angle))
            # OpenCV points are (x, y) = (column, row).
            cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, int(brush_w))
            start_x, start_y = end_x, end_y
    return mask

================================================
FILE: basicsr/data/ffhq_blind_dataset.py
================================================
import cv2
import math
import random
import numpy as np
import os.path as osp
from scipy.io import loadmat
from PIL import Image
import torch
import torch.utils.data as data
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, 
                                        adjust_hue, adjust_saturation, normalize)
from basicsr.data import gaussian_kernels as gaussian_kernels
from basicsr.data.transforms import augment
from basicsr.data.data_util import paths_from_folder, brush_stroke_mask, random_ff_mask
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY

@DATASET_REGISTRY.register()
class FFHQBlindDataset(data.Dataset):
    """FFHQ face dataset that synthesizes a degraded input from each GT image.

    For each GT image the dataset optionally flips it horizontally. It then
    builds the network input in one of two ways:

    * a blur / downsample / noise / JPEG pipeline (``use_corrupt``), or
    * random brush strokes (``gen_inpaint_mask``).

    Optional color jitter and graying are then applied to the input only.
    Both images are converted BGR->RGB, HWC->CHW, to tensors, and normalized
    with ``mean`` / ``std``.

    Returned dict keys are ``'in'``, ``'gt'`` and ``'gt_path'``. When
    configured, the dict also contains ``'locations_in'`` /
    ``'locations_gt'`` (facial component boxes from ``component_path``) and
    ``'latent_gt'`` (from ``latent_gt_path``).
    """

    def __init__(self, opt):
        super(FFHQBlindDataset, self).__init__()
        logger = get_root_logger()
        self.opt = opt
        # file client (io backend); created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.gt_size = opt.get('gt_size', 512)
        self.in_size = opt.get('in_size', 512)
        assert self.gt_size >= self.in_size, 'Wrong setting.'

        self.mean = opt.get('mean', [0.5, 0.5, 0.5])
        self.std = opt.get('std', [0.5, 0.5, 0.5])

        self.component_path = opt.get('component_path', None)
        self.latent_gt_path = opt.get('latent_gt_path', None)

        if self.component_path is not None:
            self.crop_components = True
            self.components_dict = torch.load(self.component_path)
            self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
            self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
            self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
        else:
            self.crop_components = False

        if self.latent_gt_path is not None:
            self.load_latent_gt = True
            self.latent_gt_dict = torch.load(self.latent_gt_path)
        else:
            self.load_latent_gt = False

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
            # lmdb keys: the text before the first '.' of each meta_info.txt line
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            self.paths = paths_from_folder(self.gt_folder)

        # inpainting mask
        self.gen_inpaint_mask = opt.get('gen_inpaint_mask', False)
        if self.gen_inpaint_mask:
            logger.info('generate mask ...')
            # self.mask_max_angle = opt.get('mask_max_angle', 10)
            # self.mask_max_len = opt.get('mask_max_len', 150)
            # self.mask_max_width = opt.get('mask_max_width', 50)
            # self.mask_draw_times = opt.get('mask_draw_times', 10)

        # perform corrupt
        self.use_corrupt = opt.get('use_corrupt', True)
        # motion-blur kernels are currently disabled
        self.use_motion_kernel = False
        # self.use_motion_kernel = opt.get('use_motion_kernel', True)

        if self.use_motion_kernel:
            self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
            motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
            self.motion_kernels = torch.load(motion_kernel_path)

        if self.use_corrupt and not self.gen_inpaint_mask:
            # degradation configurations
            self.blur_kernel_size = opt['blur_kernel_size']
            self.blur_sigma = opt['blur_sigma']
            self.kernel_list = opt['kernel_list']
            self.kernel_prob = opt['kernel_prob']
            self.downsample_range = opt['downsample_range']
            self.noise_range = opt['noise_range']
            self.jpeg_range = opt['jpeg_range']
            # print
            logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
            logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
            logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
            logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')

        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob', None)
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')

        # to gray (only log when it can actually happen)
        self.gray_prob = opt.get('gray_prob', 0.0)
        if self.gray_prob:
            logger.info(f'Use random gray. Prob: {self.gray_prob}')
        # shift is configured on the 0-255 scale; images are in [0, 1]
        self.color_jitter_shift /= 255.

    @staticmethod
    def color_jitter(img, shift):
        """jitter color: randomly jitter the RGB values, in numpy formats"""
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def get_component_locations(self, name, status):
        """Compute facial component boxes for image ``name``.

        Args:
            name (str): Key into ``self.components_dict``.
            status (tuple): Status returned by ``augment``; ``status[0]`` is
                the horizontal-flip flag.

        Returns:
            tuple[dict, dict]: ``(locations_gt, locations_in)``. Each maps
            'left_eye', 'right_eye', 'nose' and 'mouth' to a float tensor
            ``[b0 - r + 1, b1 - r + 1, b0 + r, b1 + r]``, where
            ``(b0, b1) = bbox[0:2]`` and ``r = bbox[2]`` scaled by the part's
            enlarge ratio. ``b0`` is the width coordinate; ``b1`` is
            presumably the height coordinate (confirm against the component
            file). The first dict is at GT resolution; the second is rescaled
            to the input resolution.
        """
        parts = ('left_eye', 'right_eye', 'nose', 'mouth')
        # Work on copies: flipping the cached boxes in place would corrupt
        # ``self.components_dict`` and compound on every flipped access.
        bboxes = {part: np.array(self.components_dict[name][part]) for part in parts}
        if status[0]:  # hflip
            # exchange right and left eye
            bboxes['left_eye'], bboxes['right_eye'] = bboxes['right_eye'], bboxes['left_eye']
            # mirror the width coordinate
            for part in parts:
                bboxes[part][0] = self.gt_size - bboxes[part][0]

        enlarge_ratios = {
            'left_eye': self.eye_enlarge_ratio,
            'right_eye': self.eye_enlarge_ratio,
            'nose': self.nose_enlarge_ratio,
            'mouth': self.mouth_enlarge_ratio,
        }
        # True division so that non-integer gt_size / in_size ratios are handled.
        in_scale = self.in_size / self.gt_size
        locations_gt = {}
        locations_in = {}
        for part in parts:
            mean = bboxes[part][0:2]
            half_len = bboxes[part][2] * enlarge_ratios[part]
            loc = torch.from_numpy(np.hstack((mean - half_len + 1, mean + half_len))).float()
            locations_gt[part] = loc
            locations_in[part] = loc * in_scale
        return locations_gt, locations_in

    def __getitem__(self, index):
        """Load GT image ``index`` and synthesize its degraded input.

        Returns:
            dict: ``'in'`` and ``'gt'`` normalized tensors and ``'gt_path'``.
            When enabled, it also holds ``'locations_in'``,
            ``'locations_gt'`` and ``'latent_gt'``.
        """
        if self.file_client is None:
            # Pop 'type' from a copy so ``self.io_backend_opt`` (the dict
            # shared with ``opt``) is not mutated.
            backend_opt = dict(self.io_backend_opt)
            self.file_client = FileClient(backend_opt.pop('type'), **backend_opt)

        # load gt image
        gt_path = self.paths[index]
        # Strip the extension; lmdb keys have none, so nothing is cut from them.
        name = osp.splitext(osp.basename(gt_path))[0]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)

        if self.load_latent_gt:
            # latent codes are stored separately for flipped and original images
            latent_gt = self.latent_gt_dict['hflip' if status[0] else 'orig'][name]

        if self.crop_components:
            locations_gt, locations_in = self.get_component_locations(name, status)

        # generate in image
        img_in = img_gt
        if self.use_corrupt and not self.gen_inpaint_mask:
            # motion blur
            if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
                m_i = random.randint(0, 31)
                k = self.motion_kernels[f'{m_i:02d}']
                img_in = cv2.filter2D(img_in, -1, k)

            # gaussian blur
            kernel = gaussian_kernels.random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                self.blur_kernel_size,
                self.blur_sigma,
                self.blur_sigma,
                [-math.pi, math.pi],
                noise_range=None)
            img_in = cv2.filter2D(img_in, -1, kernel)

            # downsample
            scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1])
            img_in = cv2.resize(img_in, (int(self.gt_size // scale), int(self.gt_size // scale)), interpolation=cv2.INTER_LINEAR)

            # noise
            if self.noise_range is not None:
                noise_sigma = np.random.uniform(self.noise_range[0] / 255., self.noise_range[1] / 255.)
                noise = np.float32(np.random.randn(*(img_in.shape))) * noise_sigma
                img_in = img_in + noise
                img_in = np.clip(img_in, 0, 1)

            # jpeg
            if self.jpeg_range is not None:
                jpeg_p = np.random.uniform(self.jpeg_range[0], self.jpeg_range[1])
                encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
                _, encimg = cv2.imencode('.jpg', img_in * 255., encode_param)
                img_in = np.float32(cv2.imdecode(encimg, 1)) / 255.

            # resize to in_size
            img_in = cv2.resize(img_in, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)

        # if self.gen_inpaint_mask:
        #     inpaint_mask = random_ff_mask(shape=(self.gt_size,self.gt_size),
        #         max_angle = self.mask_max_angle, max_len = self.mask_max_len,
        #         max_width = self.mask_max_width, times = self.mask_draw_times)
        #     img_in = img_in * (1 - inpaint_mask.reshape(self.gt_size,self.gt_size,1)) + \
        #              1.0 * inpaint_mask.reshape(self.gt_size,self.gt_size,1)
        #     inpaint_mask = torch.from_numpy(inpaint_mask).view(1,self.gt_size,self.gt_size)

        if self.gen_inpaint_mask:
            # Round (not truncate) to uint8 for PIL. Convert back as float32 so
            # the later OpenCV calls (e.g. cvtColor) accept the array.
            img_in = np.clip(np.round(img_in * 255.), 0, 255).astype(np.uint8)
            img_in = brush_stroke_mask(Image.fromarray(img_in))
            img_in = np.asarray(img_in, dtype=np.float32) / 255.

        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_in = self.color_jitter(img_in, self.color_jitter_shift)
        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
            img_in = np.tile(img_in[:, :, None], [1, 1, 3])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_in, img_gt = img2tensor([img_in, img_gt], bgr2rgb=True, float32=True)

        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)

        # round and clip to the 8-bit grid (img_in is a tensor here)
        img_in = torch.clamp((img_in * 255.0).round(), 0, 255) / 255.

        # Set vgg range_norm=True if use the normalization here
        # normalize
        normalize(img_in, self.mean, self.std, inplace=True)
        normalize(img_gt, self.mean, self.std, inplace=True)

        return_dict = {'in': img_in, 'gt': img_gt, 'gt_path': gt_path}

        if self.crop_components:
            return_dict['locations_in'] = locations_in
            return_dict['locations_gt'] = locations_gt

        if self.load_latent_gt:
            return_dict['latent_gt'] = latent_gt

        # if self.gen_inpaint_mask:
        #     return_dict['inpaint_mask'] = inpaint_mask

        return return_dict

    def __len__(self):
        return len(self.paths)

================================================
FILE: basicsr/data/ffhq_blind_joint_dataset.py
================================================
import cv2
import math
import random
import numpy as np
import os.path as osp
from scipy.io import loadmat
import torch
import torch.utils.data as data
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, 
                                        adjust_hue, adjust_saturation, normalize)
from basicsr.data import gaussian_kernels as gaussian_kernels
from basicsr.data.transforms import augment
from basicsr.data.data_util import paths_from_folder
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY

@DATASET_REGISTRY.register()
class FFHQBlindJointDataset(data.Dataset):
    """FFHQ face dataset that returns two degraded versions of each GT image.

    ``'in'`` is produced with the "small" degradation settings
    (``blur_sigma``, ``downsample_range``, ``noise_range``, ``jpeg_range``).
    ``'in_large_de'`` is produced with the matching ``*_large`` settings.

    Both inputs share one random decision on whether color jitter / graying
    is applied. Jitter values are drawn separately for each input. All images
    are converted BGR->RGB, HWC->CHW, to tensors, and normalized with
    ``mean`` / ``std``.

    When configured, the dict also contains ``'locations_in'`` /
    ``'locations_gt'`` (from ``component_path``) and ``'latent_gt'`` (from
    ``latent_gt_path``).
    """

    def __init__(self, opt):
        super(FFHQBlindJointDataset, self).__init__()
        logger = get_root_logger()
        self.opt = opt
        # file client (io backend); created lazily in __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.gt_size = opt.get('gt_size', 512)
        self.in_size = opt.get('in_size', 512)
        assert self.gt_size >= self.in_size, 'Wrong setting.'

        self.mean = opt.get('mean', [0.5, 0.5, 0.5])
        self.std = opt.get('std', [0.5, 0.5, 0.5])

        self.component_path = opt.get('component_path', None)
        self.latent_gt_path = opt.get('latent_gt_path', None)

        if self.component_path is not None:
            self.crop_components = True
            self.components_dict = torch.load(self.component_path)
            self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1.4)
            self.nose_enlarge_ratio = opt.get('nose_enlarge_ratio', 1.1)
            self.mouth_enlarge_ratio = opt.get('mouth_enlarge_ratio', 1.3)
        else:
            self.crop_components = False

        if self.latent_gt_path is not None:
            self.load_latent_gt = True
            self.latent_gt_dict = torch.load(self.latent_gt_path)
        else:
            self.load_latent_gt = False

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError("'dataroot_gt' should end with '.lmdb', "f'but received {self.gt_folder}')
            # lmdb keys: the text before the first '.' of each meta_info.txt line
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            self.paths = paths_from_folder(self.gt_folder)

        # perform corrupt
        self.use_corrupt = opt.get('use_corrupt', True)
        # motion-blur kernels are currently disabled
        self.use_motion_kernel = False
        # self.use_motion_kernel = opt.get('use_motion_kernel', True)

        if self.use_motion_kernel:
            self.motion_kernel_prob = opt.get('motion_kernel_prob', 0.001)
            motion_kernel_path = opt.get('motion_kernel_path', 'basicsr/data/motion-blur-kernels-32.pth')
            self.motion_kernels = torch.load(motion_kernel_path)

        if self.use_corrupt:
            # degradation configurations
            self.blur_kernel_size = self.opt['blur_kernel_size']
            self.kernel_list = self.opt['kernel_list']
            self.kernel_prob = self.opt['kernel_prob']
            # Small degradation
            self.blur_sigma = self.opt['blur_sigma']
            self.downsample_range = self.opt['downsample_range']
            self.noise_range = self.opt['noise_range']
            self.jpeg_range = self.opt['jpeg_range']
            # Large degradation
            self.blur_sigma_large = self.opt['blur_sigma_large']
            self.downsample_range_large = self.opt['downsample_range_large']
            self.noise_range_large = self.opt['noise_range_large']
            self.jpeg_range_large = self.opt['jpeg_range_large']

            # print
            logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]')
            logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]')
            logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
            logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]')

        # color jitter
        self.color_jitter_prob = opt.get('color_jitter_prob', None)
        self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob', None)
        self.color_jitter_shift = opt.get('color_jitter_shift', 20)
        if self.color_jitter_prob is not None:
            logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}')

        # to gray (only log when it can actually happen)
        self.gray_prob = opt.get('gray_prob', 0.0)
        if self.gray_prob:
            logger.info(f'Use random gray. Prob: {self.gray_prob}')
        # shift is configured on the 0-255 scale; images are in [0, 1]
        self.color_jitter_shift /= 255.

    @staticmethod
    def color_jitter(img, shift):
        """jitter color: randomly jitter the RGB values, in numpy formats"""
        jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)
        img = img + jitter_val
        img = np.clip(img, 0, 1)
        return img

    @staticmethod
    def color_jitter_pt(img, brightness, contrast, saturation, hue):
        """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats"""
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)

            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)

            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)

            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def get_component_locations(self, name, status):
        """Compute facial component boxes for image ``name``.

        Args:
            name (str): Key into ``self.components_dict``.
            status (tuple): Status returned by ``augment``; ``status[0]`` is
                the horizontal-flip flag.

        Returns:
            tuple[dict, dict]: ``(locations_gt, locations_in)``. Each maps
            'left_eye', 'right_eye', 'nose' and 'mouth' to a float tensor
            ``[b0 - r + 1, b1 - r + 1, b0 + r, b1 + r]``, where
            ``(b0, b1) = bbox[0:2]`` and ``r = bbox[2]`` scaled by the part's
            enlarge ratio. ``b0`` is the width coordinate; ``b1`` is
            presumably the height coordinate (confirm against the component
            file). The first dict is at GT resolution; the second is rescaled
            to the input resolution.
        """
        parts = ('left_eye', 'right_eye', 'nose', 'mouth')
        # Work on copies: flipping the cached boxes in place would corrupt
        # ``self.components_dict`` and compound on every flipped access.
        bboxes = {part: np.array(self.components_dict[name][part]) for part in parts}
        if status[0]:  # hflip
            # exchange right and left eye
            bboxes['left_eye'], bboxes['right_eye'] = bboxes['right_eye'], bboxes['left_eye']
            # mirror the width coordinate
            for part in parts:
                bboxes[part][0] = self.gt_size - bboxes[part][0]

        enlarge_ratios = {
            'left_eye': self.eye_enlarge_ratio,
            'right_eye': self.eye_enlarge_ratio,
            'nose': self.nose_enlarge_ratio,
            'mouth': self.mouth_enlarge_ratio,
        }
        # True division so that non-integer gt_size / in_size ratios are handled.
        in_scale = self.in_size / self.gt_size
        locations_gt = {}
        locations_in = {}
        for part in parts:
            mean = bboxes[part][0:2]
            half_len = bboxes[part][2] * enlarge_ratios[part]
            loc = torch.from_numpy(np.hstack((mean - half_len + 1, mean + half_len))).float()
            locations_gt[part] = loc
            locations_in[part] = loc * in_scale
        return locations_gt, locations_in

    def _degrade(self, img, blur_sigma, downsample_range, noise_range, jpeg_range):
        """Degrade a GT image through the blur, downsample, noise, JPEG and resize pipeline.

        The steps are: optional motion blur, mixed Gaussian blur, downsampling,
        optional noise, optional JPEG, then a resize to ``in_size``.
        """
        # motion blur
        if self.use_motion_kernel and random.random() < self.motion_kernel_prob:
            m_i = random.randint(0, 31)
            k = self.motion_kernels[f'{m_i:02d}']
            img = cv2.filter2D(img, -1, k)

        # gaussian blur
        kernel = gaussian_kernels.random_mixed_kernels(
            self.kernel_list,
            self.kernel_prob,
            self.blur_kernel_size,
            blur_sigma,
            blur_sigma,
            [-math.pi, math.pi],
            noise_range=None)
        img = cv2.filter2D(img, -1, kernel)

        # downsample
        scale = np.random.uniform(downsample_range[0], downsample_range[1])
        down_size = int(self.gt_size // scale)
        img = cv2.resize(img, (down_size, down_size), interpolation=cv2.INTER_LINEAR)

        # noise
        if noise_range is not None:
            noise_sigma = np.random.uniform(noise_range[0] / 255., noise_range[1] / 255.)
            noise = np.float32(np.random.randn(*(img.shape))) * noise_sigma
            img = np.clip(img + noise, 0, 1)

        # jpeg
        if jpeg_range is not None:
            jpeg_p = np.random.uniform(jpeg_range[0], jpeg_range[1])
            encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(jpeg_p)]
            _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
            img = np.float32(cv2.imdecode(encimg, 1)) / 255.

        # resize to in_size
        return cv2.resize(img, (self.in_size, self.in_size), interpolation=cv2.INTER_LINEAR)

    def __getitem__(self, index):
        """Load GT image ``index`` and synthesize its small- and large-degradation inputs.

        Returns:
            dict: normalized tensors ``'in'``, ``'in_large_de'`` and ``'gt'``,
            plus ``'gt_path'``. When enabled, it also holds
            ``'locations_in'``, ``'locations_gt'`` and ``'latent_gt'``.
        """
        if self.file_client is None:
            # Pop 'type' from a copy so ``self.io_backend_opt`` (the dict
            # shared with ``opt``) is not mutated.
            backend_opt = dict(self.io_backend_opt)
            self.file_client = FileClient(backend_opt.pop('type'), **backend_opt)

        # load gt image
        gt_path = self.paths[index]
        # Strip the extension; lmdb keys have none, so nothing is cut from them.
        name = osp.splitext(osp.basename(gt_path))[0]
        img_bytes = self.file_client.get(gt_path)
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True)

        if self.load_latent_gt:
            # latent codes are stored separately for flipped and original images
            latent_gt = self.latent_gt_dict['hflip' if status[0] else 'orig'][name]

        if self.crop_components:
            locations_gt, locations_in = self.get_component_locations(name, status)

        # generate in (small degradation) and in_large (large degradation);
        # the small one is drawn first, as before
        img_in = img_gt
        img_in_large = img_gt
        if self.use_corrupt:
            img_in = self._degrade(img_gt, self.blur_sigma, self.downsample_range,
                                   self.noise_range, self.jpeg_range)
            img_in_large = self._degrade(img_gt, self.blur_sigma_large, self.downsample_range_large,
                                         self.noise_range_large, self.jpeg_range_large)

        # random color jitter (only for lq)
        if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob):
            img_in = self.color_jitter(img_in, self.color_jitter_shift)
            img_in_large = self.color_jitter(img_in_large, self.color_jitter_shift)
        # random to gray (only for lq)
        if self.gray_prob and np.random.uniform() < self.gray_prob:
            img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2GRAY)
            img_in = np.tile(img_in[:, :, None], [1, 1, 3])
            img_in_large = cv2.cvtColor(img_in_large, cv2.COLOR_BGR2GRAY)
            img_in_large = np.tile(img_in_large[:, :, None], [1, 1, 3])

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_in, img_in_large, img_gt = img2tensor([img_in, img_in_large, img_gt], bgr2rgb=True, float32=True)

        # random color jitter (pytorch version) (only for lq)
        if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob):
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            img_in = self.color_jitter_pt(img_in, brightness, contrast, saturation, hue)
            img_in_large = self.color_jitter_pt(img_in_large, brightness, contrast, saturation, hue)

        # round and clip to the 8-bit grid (inputs are tensors here)
        img_in = torch.clamp((img_in * 255.0).round(), 0, 255) / 255.
        img_in_large = torch.clamp((img_in_large * 255.0).round(), 0, 255) / 255.

        # Set vgg range_norm=True if use the normalization here
        # normalize
        normalize(img_in, self.mean, self.std, inplace=True)
        normalize(img_in_large, self.mean, self.std, inplace=True)
        normalize(img_gt, self.mean, self.std, inplace=True)

        return_dict = {'in': img_in, 'in_large_de': img_in_large, 'gt': img_gt, 'gt_path': gt_path}

        if self.crop_components:
            return_dict['locations_in'] = locations_in
            return_dict['locations_gt'] = locations_gt

        if self.load_latent_gt:
            return_dict['latent_gt'] = latent_gt

        return return_dict

    def __len__(self):
        return len(self.paths)


================================================
FILE: basicsr/data/gaussian_kernels.py
================================================
import math
import numpy as np
import random
from scipy.ndimage.interpolation import shift
from scipy.stats import multivariate_normal


def sigma_matrix2(sig_x, sig_y, theta):
    """Build the covariance matrix of an axis-aligned Gaussian rotated by ``theta``.

    Computes ``U @ diag(sig_x**2, sig_y**2) @ U.T`` where ``U`` is the 2D
    rotation matrix for ``theta``.

    Args:
        sig_x (float): Standard deviation along the unrotated x axis.
        sig_y (float): Standard deviation along the unrotated y axis.
        theta (float): Rotation angle in radians.

    Returns:
        ndarray: Rotated sigma (covariance) matrix with shape (2, 2).
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    variances = np.array([[sig_x**2, 0], [0, sig_y**2]])
    return rotation.dot(variances.dot(rotation.T))


def mesh_grid(kernel_size):
    """Generate a 2D coordinate grid centred at zero.

    For an odd ``kernel_size`` the coordinates run from ``-(kernel_size // 2)``
    to ``kernel_size // 2`` along each axis.

    Args:
        kernel_size (int): Number of samples along each axis.

    Returns:
        xy (ndarray): Stacked coordinates, shape (kernel_size, kernel_size, 2);
            ``xy[..., 0] == xx`` and ``xy[..., 1] == yy``.
        xx (ndarray): x coordinates, shape (kernel_size, kernel_size).
        yy (ndarray): y coordinates, shape (kernel_size, kernel_size).
    """
    upper = kernel_size // 2 + 1.
    coords = np.arange(-kernel_size // 2 + 1., upper)
    xx, yy = np.meshgrid(coords, coords)
    xy = np.stack((xx, yy), axis=-1)
    return xy, xx, yy


def pdf2(sigma_matrix, grid):
    """Evaluate an un-normalized bivariate Gaussian density on a grid.

    Computes ``exp(-0.5 * x^T Sigma^{-1} x)`` for every coordinate ``x`` in
    ``grid``.

    Args:
        sigma_matrix (ndarray): Covariance matrix with shape (2, 2).
        grid (ndarray): Coordinates as produced by :func:`mesh_grid`, with
            shape (K, K, 2), K being the kernel size.

    Returns:
        ndarray: Un-normalized kernel with shape (K, K).
    """
    precision = np.linalg.inv(sigma_matrix)
    # Quadratic form x^T Sigma^{-1} x, evaluated point-wise over the grid.
    quad_form = np.sum(np.dot(grid, precision) * grid, 2)
    return np.exp(-0.5 * quad_form)


def cdf2(D, grid):
    """Evaluate the standard bivariate normal CDF at skew-transformed coordinates.

    Used by the skewed Gaussian kernel: each grid coordinate is multiplied by
    the skew matrix ``D`` and the CDF of N(0, I) is taken at the result.

    Args:
        D (ndarray): Skew matrix with shape (2, 2).
        grid (ndarray): Coordinates as produced by :func:`mesh_grid`, with
            shape (K, K, 2), K being the kernel size.

    Returns:
        ndarray: Skewed CDF values.
    """
    standard_normal = multivariate_normal([0, 0], [[1, 0], [0, 1]])
    skewed_grid = np.dot(grid, D)
    return standard_normal.cdf(skewed_grid)


def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=None):
    """Generate a bivariate skew Gaussian kernel.

    The kernel is the product of a rotated Gaussian density (:func:`pdf2`) and
    a skewing CDF term (:func:`cdf2`), normalized to sum to one. Described in
    `A multivariate skew normal distribution`_ by Shi et al. (2004).

    Args:
        kernel_size (int): Kernel size.
        sig_x (float): Standard deviation along the unrotated x axis.
        sig_y (float): Standard deviation along the unrotated y axis.
        theta (float): Rotation angle in radians.
        D (ndarray): Skew matrix.
        grid (ndarray, optional): Coordinates from :func:`mesh_grid`, shape
            (K, K, 2). Built from ``kernel_size`` when None. Default: None.

    Returns:
        ndarray: Normalized kernel.

    .. _A multivariate skew normal distribution:
        https://www.sciencedirect.com/science/article/pii/S0047259X03001313
    """
    grid = mesh_grid(kernel_size)[0] if grid is None else grid
    density = pdf2(sigma_matrix2(sig_x, sig_y, theta), grid)
    skew_term = cdf2(D, grid)
    unnormalized = density * skew_term
    return unnormalized / np.sum(unnormalized)


def mass_center_shift(kernel_size, kernel):
    """Compute how far a kernel's centre of mass lies from the grid centre.

    Args:
        kernel_size (int): Kernel size (length of each side of ``kernel``).
        kernel (ndarray): Normalized kernel.

    Returns:
        delta_h (float): Vertical (row-axis) offset of the centre of mass.
        delta_w (float): Horizontal (column-axis) offset of the centre of mass.
    """
    offsets = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    # Marginal mass of each row (vertical profile) and each column (horizontal profile).
    per_row = np.sum(kernel, axis=1)
    per_col = np.sum(kernel, axis=0)
    return np.dot(per_row, offsets), np.dot(per_col, offsets)


def bivariate_skew_Gaussian_center(kernel_size,
                                   sig_x,
                                   sig_y,
                                   theta,
                                   D,
                                   grid=None):
    """Generate a bivariate skew Gaussian kernel re-centred on the grid centre.

    The skew kernel's centre of mass is shifted back to the middle with
    ``scipy.ndimage.shift`` (``mode='nearest'`` at the borders) and the result
    is renormalized.

    Args:
        kernel_size (int): Kernel size.
        sig_x (float): Standard deviation along the unrotated x axis.
        sig_y (float): Standard deviation along the unrotated y axis.
        theta (float): Rotation angle in radians.
        D (ndarray): Skew matrix.
        grid (ndarray, optional): Coordinates from :func:`mesh_grid`, shape
            (K, K, 2). Built from ``kernel_size`` when None. Default: None.

    Returns:
        ndarray: Centred and normalized kernel.
    """
    grid = mesh_grid(kernel_size)[0] if grid is None else grid
    skewed = bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid)
    offset_h, offset_w = mass_center_shift(kernel_size, skewed)
    centered = shift(skewed, [-offset_h, -offset_w], mode='nearest')
    return centered / np.sum(centered)


def bivariate_anisotropic_Gaussian(kernel_size,
                                   sig_x,
                                   sig_y,
                                   theta,
                                   grid=None):
    """Generate a normalized bivariate anisotropic (rotated) Gaussian kernel.

    Args:
        kernel_size (int): Kernel size.
        sig_x (float): Standard deviation along the unrotated x axis.
        sig_y (float): Standard deviation along the unrotated y axis.
        theta (float): Rotation angle in radians.
        grid (ndarray, optional): Coordinates from :func:`mesh_grid`, shape
            (K, K, 2). Built from ``kernel_size`` when None. Default: None.

    Returns:
        ndarray: Normalized kernel.
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    unnormalized = pdf2(sigma_matrix2(sig_x, sig_y, theta), grid)
    return unnormalized / np.sum(unnormalized)


def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None):
    """Generate a normalized bivariate isotropic Gaussian kernel.

    Args:
        kernel_size (int): Kernel size.
        sig (float): Standard deviation shared by both axes.
        grid (ndarray, optional): Coordinates from :func:`mesh_grid`, shape
            (K, K, 2). Built from ``kernel_size`` when None. Default: None.

    Returns:
        ndarray: Normalized kernel.
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    variance = sig**2
    covariance = np.array([[variance, 0], [0, variance]])
    unnormalized = pdf2(covariance, grid)
    return unnormalized / np.sum(unnormalized)


def bivariate_generalized_Gaussian(kernel_size,
                                   sig_x,
                                   sig_y,
                                   theta,
                                   beta,
                                   grid=None):
    """Generate a normalized bivariate generalized Gaussian kernel.

    Computes ``exp(-0.5 * (x^T Sigma^{-1} x) ** beta)``. Described in
    `Parameter Estimation For Multivariate Generalized Gaussian Distributions`_
    by Pascal et al. (2013).

    Args:
        kernel_size (int): Kernel size.
        sig_x (float): Standard deviation along the unrotated x axis.
        sig_y (float): Standard deviation along the unrotated y axis.
        theta (float): Rotation angle in radians.
        beta (float): Shape parameter; beta = 1 gives the normal distribution.
        grid (ndarray, optional): Coordinates from :func:`mesh_grid`, shape
            (K, K, 2). Built from ``kernel_size`` when None. Default: None.

    Returns:
        ndarray: Normalized kernel.

    .. _Parameter Estimation For Multivariate Generalized Gaussian Distributions:
        https://arxiv.org/abs/1302.6498
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    precision = np.linalg.inv(sigma_matrix2(sig_x, sig_y, theta))
    quad_form = np.sum(np.dot(grid, precision) * grid, 2)
    unnormalized = np.exp(-0.5 * np.power(quad_form, beta))
    return unnormalized / np.sum(unnormalized)


def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid=None):
    """Generate a normalized plateau-shaped anisotropic kernel.

    Computes ``1 / (1 + (x^T Sigma^{-1} x) ** beta)``.

    Args:
        kernel_size (int): Kernel size.
        sig_x (float): Scale along the unrotated x axis.
        sig_y (float): Scale along the unrotated y axis.
        theta (float): Rotation angle in radians.
        beta (float): Shape parameter; larger values give a flatter top and a
            steeper fall-off outside it.
        grid (ndarray, optional): Coordinates from :func:`mesh_grid`, shape
            (K, K, 2). Built from ``kernel_size`` when None. Default: None.

    Returns:
        ndarray: Normalized kernel.
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    precision = np.linalg.inv(sigma_matrix2(sig_x, sig_y, theta))
    quad_form = np.sum(np.dot(grid, precision) * grid, 2)
    unnormalized = np.reciprocal(np.power(quad_form, beta) + 1)
    return unnormalized / np.sum(unnormalized)


def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None):
    """Generate a normalized plateau-shaped isotropic kernel.

    Computes ``1 / (1 + (x^T Sigma^{-1} x) ** beta)`` with
    ``Sigma = diag(sig**2, sig**2)``.

    Args:
        kernel_size (int): Kernel size.
        sig (float): Scale shared by both axes.
        beta (float): Shape parameter; larger values give a flatter top and a
            steeper fall-off outside it.
        grid (ndarray, optional): Coordinates from :func:`mesh_grid`, shape
            (K, K, 2). Built from ``kernel_size`` when None. Default: None.

    Returns:
        ndarray: Normalized kernel.
    """
    if grid is None:
        grid = mesh_grid(kernel_size)[0]
    variance = sig**2
    precision = np.linalg.inv(np.array([[variance, 0], [0, variance]]))
    quad_form = np.sum(np.dot(grid, precision) * grid, 2)
    unnormalized = np.reciprocal(np.power(quad_form, beta) + 1)
    return unnormalized / np.sum(unnormalized)


def random_bivariate_skew_Gaussian_center(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          noise_range=None,
                                          strict=False):
    """Randomly sample a centred bivariate skew Gaussian kernel.

    Each entry of the 2x2 skew matrix ``D`` is drawn uniformly from
    ``[-3 / max(sigma_x, sigma_y), 3 / max(sigma_x, sigma_y))``.

    Args:
        kernel_size (int): Kernel size; must be odd.
        sigma_x_range (tuple): (low, high) for sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): (low, high) for sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): (low, high) rotation in radians,
            e.g. [-math.pi, math.pi].
        noise_range (tuple, optional): (low, high) of element-wise
            multiplicative kernel noise, e.g. [0.75, 1.25]. Default: None.
        strict (bool): If True, reorder the sigmas so that sigma_x >= sigma_y
            and also return the sampled parameters. Default: False.

    Returns:
        ndarray | tuple: The normalized kernel, or
            ``(kernel, sigma_x, sigma_y, rotation, D)`` when ``strict``.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'

    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        sigma_x, sigma_y = np.max([sigma_x, sigma_y]), np.min([sigma_x, sigma_y])
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    limit = 3 / np.max([sigma_x, sigma_y])
    # Sampled row by row: D[0][0], D[0][1], D[1][0], D[1][1].
    D = [[np.random.uniform(-limit, limit) for _ in range(2)] for _ in range(2)]

    kernel = bivariate_skew_Gaussian_center(kernel_size, sigma_x, sigma_y,
                                            rotation, D)

    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    if not strict:
        return kernel
    return kernel, sigma_x, sigma_y, rotation, D


def random_bivariate_anisotropic_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          noise_range=None,
                                          strict=False):
    """Randomly sample a bivariate anisotropic Gaussian kernel.

    Args:
        kernel_size (int): Kernel size; must be odd.
        sigma_x_range (tuple): (low, high) for sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): (low, high) for sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): (low, high) rotation in radians,
            e.g. [-math.pi, math.pi].
        noise_range (tuple, optional): (low, high) of element-wise
            multiplicative kernel noise, e.g. [0.75, 1.25]. Default: None.
        strict (bool): If True, reorder the sigmas so that sigma_x >= sigma_y
            and also return the sampled parameters. Default: False.

    Returns:
        ndarray | tuple: The normalized kernel, or
            ``(kernel, sigma_x, sigma_y, rotation)`` when ``strict``.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'

    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        sigma_x, sigma_y = np.max([sigma_x, sigma_y]), np.min([sigma_x, sigma_y])
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    kernel = bivariate_anisotropic_Gaussian(kernel_size, sigma_x, sigma_y,
                                            rotation)

    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    if not strict:
        return kernel
    return kernel, sigma_x, sigma_y, rotation


def random_bivariate_isotropic_Gaussian(kernel_size,
                                        sigma_range,
                                        noise_range=None,
                                        strict=False):
    """Randomly sample a bivariate isotropic Gaussian kernel.

    Args:
        kernel_size (int): Kernel size; must be odd.
        sigma_range (tuple): (low, high) for sigma, e.g. [0.6, 5].
        noise_range (tuple, optional): (low, high) of element-wise
            multiplicative kernel noise, e.g. [0.75, 1.25]. Default: None.
        strict (bool): If True, also return the sampled sigma. Default: False.

    Returns:
        ndarray | tuple: The normalized kernel, or ``(kernel, sigma)`` when
            ``strict``.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.'

    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
    kernel = bivariate_isotropic_Gaussian(kernel_size, sigma)

    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    return (kernel, sigma) if strict else kernel


def random_bivariate_generalized_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          beta_range,
                                          noise_range=None,
                                          strict=False):
    """Randomly sample a bivariate generalized Gaussian kernel.

    With probability 0.5 beta is drawn from ``[beta_range[0], 1)``, otherwise
    from ``[1, beta_range[1])``.

    Args:
        kernel_size (int): Kernel size; must be odd.
        sigma_x_range (tuple): (low, high) for sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): (low, high) for sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): (low, high) rotation in radians,
            e.g. [-math.pi, math.pi].
        beta_range (tuple): (low, high) for the shape parameter, e.g. [0.5, 8].
        noise_range (tuple, optional): (low, high) of element-wise
            multiplicative kernel noise, e.g. [0.75, 1.25]. Default: None.
        strict (bool): If True, reorder the sigmas so that sigma_x >= sigma_y
            and also return the sampled parameters. Default: False.

    Returns:
        ndarray | tuple: The normalized kernel, or
            ``(kernel, sigma_x, sigma_y, rotation, beta)`` when ``strict``.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'

    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        sigma_x, sigma_y = np.max([sigma_x, sigma_y]), np.min([sigma_x, sigma_y])
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    beta_low, beta_high = ((beta_range[0], 1) if np.random.uniform() < 0.5
                           else (1, beta_range[1]))
    beta = np.random.uniform(beta_low, beta_high)

    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y,
                                            rotation, beta)

    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    if not strict:
        return kernel
    return kernel, sigma_x, sigma_y, rotation, beta


def random_bivariate_plateau_type1(kernel_size,
                                   sigma_x_range,
                                   sigma_y_range,
                                   rotation_range,
                                   beta_range,
                                   noise_range=None,
                                   strict=False):
    """Randomly sample a bivariate plateau (type 1) anisotropic kernel.

    With probability 0.5 beta is drawn from ``[beta_range[0], 1)``, otherwise
    from ``[1, beta_range[1])``.

    Args:
        kernel_size (int): Kernel size; must be odd.
        sigma_x_range (tuple): (low, high) for sigma_x, e.g. [0.6, 5].
        sigma_y_range (tuple): (low, high) for sigma_y, e.g. [0.6, 5].
        rotation_range (tuple): (low, high) rotation in radians,
            e.g. [-math.pi/2, math.pi/2].
        beta_range (tuple): (low, high) for the shape parameter, e.g. [1, 4].
        noise_range (tuple, optional): (low, high) of element-wise
            multiplicative kernel noise, e.g. [0.75, 1.25]. Default: None.
        strict (bool): If True, reorder the sigmas so that sigma_x >= sigma_y
            and also return the sampled parameters. Default: False.

    Returns:
        ndarray | tuple: The normalized kernel, or
            ``(kernel, sigma_x, sigma_y, rotation, beta)`` when ``strict``.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
    assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'

    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
    if strict:
        sigma_x, sigma_y = np.max([sigma_x, sigma_y]), np.min([sigma_x, sigma_y])
    rotation = np.random.uniform(rotation_range[0], rotation_range[1])

    beta_low, beta_high = ((beta_range[0], 1) if np.random.uniform() < 0.5
                           else (1, beta_range[1]))
    beta = np.random.uniform(beta_low, beta_high)

    kernel = bivariate_plateau_type1(kernel_size, sigma_x, sigma_y, rotation,
                                     beta)

    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    if not strict:
        return kernel
    return kernel, sigma_x, sigma_y, rotation, beta


def random_bivariate_plateau_type1_iso(kernel_size,
                                       sigma_range,
                                       beta_range,
                                       noise_range=None,
                                       strict=False):
    """Randomly sample a bivariate plateau (type 1) isotropic kernel.

    Args:
        kernel_size (int): Kernel size; must be odd.
        sigma_range (tuple): (low, high) for sigma, e.g. [0.6, 5].
        beta_range (tuple): (low, high) for the shape parameter, e.g. [1, 4].
        noise_range (tuple, optional): (low, high) of element-wise
            multiplicative kernel noise, e.g. [0.75, 1.25]. Default: None.
        strict (bool): If True, also return the sampled parameters.
            Default: False.

    Returns:
        ndarray | tuple: The normalized kernel, or ``(kernel, sigma, beta)``
            when ``strict``.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.'

    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
    beta = np.random.uniform(beta_range[0], beta_range[1])
    kernel = bivariate_plateau_type1_iso(kernel_size, sigma, beta)

    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        kernel = kernel * np.random.uniform(
            noise_range[0], noise_range[1], size=kernel.shape)
    kernel = kernel / np.sum(kernel)

    return (kernel, sigma, beta) if strict else kernel


def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size=21,
                         sigma_x_range=(0.6, 5),
                         sigma_y_range=(0.6, 5),
                         rotation_range=(-math.pi, math.pi),
                         beta_range=(0.5, 8),
                         noise_range=None):
    """Randomly pick a kernel type and sample a kernel of that type.

    Args:
        kernel_list (tuple): Names of candidate kernel types. Supported:
            'iso', 'aniso', 'skew', 'generalized', 'plateau_iso', 'plateau_aniso'.
        kernel_prob (tuple): Sampling weight for each entry of ``kernel_list``.
        kernel_size (int): Kernel size; must be odd. Default: 21.
        sigma_x_range (tuple): Range of sigma_x. It is also the sigma range for
            the isotropic types ('iso', 'plateau_iso'). Default: (0.6, 5).
        sigma_y_range (tuple): Range of sigma_y. Default: (0.6, 5).
        rotation_range (tuple): Rotation range in radians.
            Default: (-math.pi, math.pi).
        beta_range (tuple): Shape-parameter range, used by 'generalized',
            'plateau_iso' and 'plateau_aniso'. Default: (0.5, 8).
        noise_range (tuple, optional): Element-wise multiplicative kernel noise
            range, e.g. [0.75, 1.25]. It is applied exactly once, by the
            sampler of the chosen type. Default: None.

    Returns:
        ndarray: Normalized kernel.

    Raises:
        ValueError: If the sampled kernel type is not supported.
    """
    kernel_type = random.choices(kernel_list, kernel_prob)[0]
    # Every sampler below already applies ``noise_range`` and renormalizes its
    # output, so noise must not be multiplied in a second time here.
    if kernel_type == 'iso':
        return random_bivariate_isotropic_Gaussian(
            kernel_size, sigma_x_range, noise_range=noise_range)
    if kernel_type == 'aniso':
        return random_bivariate_anisotropic_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            noise_range=noise_range)
    if kernel_type == 'skew':
        return random_bivariate_skew_Gaussian_center(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            noise_range=noise_range)
    if kernel_type == 'generalized':
        return random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            beta_range,
            noise_range=noise_range)
    if kernel_type == 'plateau_iso':
        return random_bivariate_plateau_type1_iso(
            kernel_size, sigma_x_range, beta_range, noise_range=noise_range)
    if kernel_type == 'plateau_aniso':
        return random_bivariate_plateau_type1(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            beta_range,
            noise_range=noise_range)
    raise ValueError(f'Unsupported kernel type: {kernel_type}')


def show_one_kernel():
    """Visualize one example kernel as a matrix, an image, a contour and a filled contour.

    Several kernel types are built one after another so the displayed one can
    be switched by editing this function; only the last one built (a
    generalized Gaussian with beta=4) is plotted. The mass-centre offset of the
    plotted kernel is printed. Requires matplotlib.
    """
    import matplotlib.pyplot as plt
    size = 21

    # bivariate skew Gaussian
    skew = [[3 / 4, 0], [0, 0.5]]
    kernel = bivariate_skew_Gaussian_center(size, 2, 4, -math.pi / 4, skew)
    # bivariate anisotropic Gaussian
    kernel = bivariate_anisotropic_Gaussian(size, 2, 4, -math.pi / 4)
    # bivariate isotropic Gaussian
    kernel = bivariate_isotropic_Gaussian(size, 1)
    # bivariate generalized Gaussian (this is the one that gets plotted)
    kernel = bivariate_generalized_Gaussian(size, 2, 4, -math.pi / 4, beta=4)

    offset_h, offset_w = mass_center_shift(size, kernel)
    print(offset_h, offset_w)

    fig, axs = plt.subplots(nrows=2, ncols=2)
    (ax_matrix, ax_image), (ax_contour, ax_filled) = axs

    # raw values as a colour-mapped matrix
    im = ax_matrix.matshow(kernel, cmap='jet', origin='upper')
    fig.colorbar(im, ax=ax_matrix)

    # min-max scaled to [0, 255] and shown as an image
    shifted = kernel - np.min(kernel)
    ax_image.imshow(shifted / np.max(shifted) * 255., interpolation='nearest')

    _, xx, yy = mesh_grid(size)
    # contour lines
    contours = ax_contour.contour(xx, yy, kernel, origin='upper')
    ax_contour.clabel(contours, inline=1, fontsize=3)

    # filled contour of the peak-normalized kernel
    peak_normalized = kernel / np.max(kernel)
    filled = ax_filled.contourf(
        xx, yy, peak_normalized, origin='upper',
        levels=np.linspace(-0.05, 1.05, 10))
    fig.colorbar(filled)

    plt.show()


def show_plateau_kernel():
    """Visualize an example plateau (type 1) kernel four ways.

    The kernel is shown as a matrix, an image, a contour and a filled contour.
    An isotropic Gaussian and a generalized Gaussian are also built for the
    optional (commented-out) 1D slice comparison below. The mass-centre offset
    of the plateau kernel is printed. Requires matplotlib.
    """
    import matplotlib.pyplot as plt
    kernel_size = 21

    # The plateau generator in this module is ``bivariate_plateau_type1``;
    # the previous call to the undefined ``plateau_type1`` raised NameError.
    kernel = bivariate_plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
    # Reference kernels, only used by the commented-out slice plot below.
    kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5)
    kernel_gau = bivariate_generalized_Gaussian(
        kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
    delta_h, delta_w = mass_center_shift(kernel_size, kernel)
    print(delta_h, delta_w)

    # kernel_slice = kernel[10, :]
    # kernel_gau_slice = kernel_gau[10, :]
    # kernel_norm_slice = kernel_norm[10, :]
    # fig, ax = plt.subplots()
    # t = list(range(1, 22))

    # ax.plot(t, kernel_gau_slice)
    # ax.plot(t, kernel_slice)
    # ax.plot(t, kernel_norm_slice)

    # t = np.arange(0, 10, 0.1)
    # y = np.exp(-0.5 * t)
    # y2 = np.reciprocal(1 + t)
    # print(t.shape)
    # print(y.shape)
    # ax.plot(t, y)
    # ax.plot(t, y2)
    # plt.show()

    fig, axs = plt.subplots(nrows=2, ncols=2)
    # axs.set_axis_off()
    ax = axs[0][0]
    im = ax.matshow(kernel, cmap='jet', origin='upper')
    fig.colorbar(im, ax=ax)

    # image: min-max scaled to [0, 255]
    ax = axs[0][1]
    kernel_vis = kernel - np.min(kernel)
    kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
    ax.imshow(kernel_vis, interpolation='nearest')

    _, xx, yy = mesh_grid(kernel_size)
    # contour
    ax = axs[1][0]
    CS = ax.contour(xx, yy, kernel, origin='upper')
    ax.clabel(CS, inline=1, fontsize=3)

    # contourf of the peak-normalized kernel
    ax = axs[1][1]
    kernel = kernel / np.max(kernel)
    p = ax.contourf(
        xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
    fig.colorbar(p)

    plt.show()


================================================
FILE: basicsr/data/paired_image_dataset.py
================================================
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Reads LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc.) and
    GT image pairs.

    Paths are collected in one of three modes:
    1. 'lmdb': used when opt['io_backend']['type'] == 'lmdb'.
    2. 'meta_info_file': used otherwise, when opt['meta_info_file'] is given
       and is not None.
    3. 'folder': fallback; scan both folders to generate paths.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwargs.
            filename_tmpl (str): Template for each filename. Note that the
                template excludes the file extension. Default: '{}'.
            mean / std (optional): Values passed to torchvision ``normalize``;
                normalization runs when either one is present.
            gt_size (int): Cropped patch size for gt patches.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h
                and w for implementation).

            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # The file client is created on first access in __getitem__.
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt.get('mean')
        self.std = opt.get('std')

        self.gt_folder = opt['dataroot_gt']
        self.lq_folder = opt['dataroot_lq']
        self.filename_tmpl = opt.get('filename_tmpl', '{}')

        backend_type = self.io_backend_opt['type']
        if backend_type == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif self.opt.get('meta_info_file') is not None:
            self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
                                                          self.opt['meta_info_file'], self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            # 'type' is removed from io_backend_opt; the remaining entries are
            # forwarded to FileClient as keyword arguments.
            backend = self.io_backend_opt.pop('type')
            self.file_client = FileClient(backend, **self.io_backend_opt)

        scale = self.opt['scale']
        entry = self.paths[index]

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = entry['gt_path']
        img_gt = imfrombytes(self.file_client.get(gt_path, 'gt'), float32=True)
        lq_path = entry['lq_path']
        img_lq = imfrombytes(self.file_client.get(lq_path, 'lq'), float32=True)

        if self.opt['phase'] == 'train':
            # random crop, then flip / rotation
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, self.opt['gt_size'], scale, gt_path)
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])

        # TODO: color space transform
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        if not (self.mean is None and self.std is None):
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)


================================================
FILE: basicsr/data/prefetch_dataloader.py
================================================
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader


class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    A daemon thread iterates ``generator`` and pushes its items into a bounded
    queue, which the consumer drains through the iterator protocol.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator (any iterable).
        num_prefetch_queue (int): Maximum number of prefetched items held in
            the queue (``queue.Queue`` maxsize; <= 0 means unbounded).
    """

    # Private end-of-stream marker. A dedicated object (instead of ``None``)
    # lets the wrapped generator legitimately yield ``None``.
    _PREFETCH_END = object()

    class _PrefetchFailure:
        """Carries an exception raised by the wrapped generator to the consumer."""

        def __init__(self, exc):
            self.exc = exc

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        try:
            for item in self.generator:
                self.queue.put(item)
        except BaseException as exc:
            # Forward the error instead of letting the thread die silently,
            # which would leave the consumer blocked on queue.get() forever.
            self.queue.put(self._PrefetchFailure(exc))
        else:
            self.queue.put(self._PREFETCH_END)

    def __next__(self):
        next_item = self.queue.get()
        if next_item is self._PREFETCH_END:
            # Re-queue the marker so further next() calls also stop instead of blocking.
            self.queue.put(self._PREFETCH_END)
            raise StopIteration
        if isinstance(next_item, self._PrefetchFailure):
            self.queue.put(self._PREFETCH_END)
            raise next_item.exc
        return next_item

    def __iter__(self):
        return self


class PrefetchDataLoader(DataLoader):
    """DataLoader whose iterator prefetches batches in a background thread.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Size of the prefetch queue handed to
            ``PrefetchGenerator``.
        kwargs (dict): All remaining arguments are forwarded to ``DataLoader``.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        # Stored before DataLoader.__init__ runs.
        self.num_prefetch_queue = num_prefetch_queue
        super().__init__(**kwargs)

    def __iter__(self):
        base_iterator = super().__iter__()
        return PrefetchGenerator(base_iterator, self.num_prefetch_queue)


class CPUPrefetcher():
    """CPU prefetcher: a thin ``next()``/``reset()`` wrapper around a loader.

    Args:
        loader: Dataloader (any iterable).
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        """Return the next batch, or None once the loader is exhausted."""
        return next(self.loader, None)

    def reset(self):
        """Restart iteration from the beginning of the original loader."""
        self.loader = iter(self.ori_loader)


class CUDAPrefetcher():
    """CUDA prefetcher.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    Copies the next batch to the target device on a side CUDA stream while
    the current batch is being consumed. It may consume more GPU memory.

    When ``opt['num_gpu'] == 0`` the target device is the CPU. In that case no
    CUDA stream is created and tensors are moved synchronously, so the class
    also works on machines without CUDA.

    Args:
        loader: Dataloader. Each batch is expected to be a dict; tensor values
            are moved to the device, other values pass through unchanged.
        opt (dict): Options. ``opt['num_gpu']`` selects the device.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        # Side stream for asynchronous host-to-device copies. It is only
        # meaningful (and only constructible) when the target device is CUDA.
        self.stream = torch.cuda.Stream() if self.device.type == 'cuda' else None
        self.preload()

    def _move_batch(self):
        """Move every tensor value of ``self.batch`` to ``self.device`` in place."""
        for k, v in self.batch.items():
            if torch.is_tensor(v):
                self.batch[k] = v.to(device=self.device, non_blocking=True)

    def preload(self):
        """Fetch the next batch (None when exhausted) and start moving it to the device."""
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        if self.stream is None:
            self._move_batch()
        else:
            with torch.cuda.stream(self.stream):
                self._move_batch()

    def next(self):
        """Return the prefetched batch (None once exhausted) and prefetch the following one."""
        if self.stream is not None:
            current = torch.cuda.current_stream()
            # The consumer stream must wait for the side-stream copies to finish.
            current.wait_stream(self.stream)
            if self.batch is not None:
                # These tensors were allocated on the side stream. Record their use on
                # the consumer stream so the caching allocator does not hand their
                # memory to the next prefetch while the consumer still uses it.
                for v in self.batch.values():
                    if torch.is_tensor(v) and v.is_cuda:
                        v.record_stream(current)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        """Restart iteration from the beginning of the original loader."""
        self.loader = iter(self.ori_loader)
        self.preload()


================================================
FILE: basicsr/data/transforms.py
================================================
import cv2
import random


def mod_crop(img, scale):
    """Crop an image so that its height and width are multiples of ``scale``.

    Used during testing. The bottom rows and right columns are dropped, and
    the result never aliases the caller's array.

    Args:
        img (ndarray): Input image with 2 or 3 dimensions.
        scale (int): Scale factor.

    Returns:
        ndarray: Cropped image.

    Raises:
        ValueError: If ``img`` is neither 2-D nor 3-D.
    """
    img = img.copy()
    if img.ndim not in (2, 3):
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    height, width = img.shape[0], img.shape[1]
    return img[:height - height % scale, :width - width % scale, ...]


def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
    """Paired random crop.

    It crops lists of lq and gt images with corresponding locations: a random
    ``gt_patch_size // scale`` square is taken from the LQ images and the
    matching ``scale``-times larger square from the GT images.

    Args:
        img_gts (list[ndarray] | ndarray): GT images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself. Both (h, w) and
            (h, w, c) arrays are accepted.
        img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
            should have the same shape. If the input is an ndarray, it will
            be transformed to a list containing itself.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.
        gt_path (str): Path to ground-truth (only used in error messages).

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. If returned results
            only have one element, just return ndarray.

    Raises:
        ValueError: If the GT size is not exactly ``scale`` times the LQ size,
            or if the LQ image is smaller than the LQ patch size.
    """

    if not isinstance(img_gts, list):
        img_gts = [img_gts]
    if not isinstance(img_lqs, list):
        img_lqs = [img_lqs]

    # Only the spatial dims matter; this also admits 2-D (grayscale) arrays.
    h_lq, w_lq = img_lqs[0].shape[:2]
    h_gt, w_gt = img_gts[0].shape[:2]
    lq_patch_size = gt_patch_size // scale

    if h_gt != h_lq * scale or w_gt != w_lq * scale:
        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
                         f'multiplication of LQ ({h_lq}, {w_lq}).')
    if h_lq < lq_patch_size or w_lq < lq_patch_size:
        raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
                         f'({lq_patch_size}, {lq_patch_size}). '
                         f'Please remove {gt_path}.')

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]

    # crop corresponding gt patch
    top_gt, left_gt = int(top * scale), int(left * scale)
    img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
    if len(img_gts) == 1:
        img_gts = img_gts[0]
    if len(img_lqs) == 1:
        img_lqs = img_lqs[0]
    return img_gts, img_lqs


def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
    """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).

    We use vertical flip and transpose for rotation implementation.
    All the images in the list use the same augmentation.

    Note: flips are done with ``cv2.flip(src, code, src)``, i.e. the input
    array is passed as its own destination. The caller's arrays are
    therefore modified by the flips.

    Args:
        imgs (list[ndarray] | ndarray): Images to be augmented. If the input
            is an ndarray, it will be transformed to a list. Both (h, w) and
            (h, w, c) arrays are supported.
        hflip (bool): Horizontal flip. Default: True.
        rotation (bool): Rotation. Default: True.
        flows (list[ndarray]: Flows to be augmented. If the input is an
            ndarray, it will be transformed to a list.
            Dimension is (h, w, 2). Default: None.
        return_status (bool): Return the status of flip and rotation.
            Default: False. Ignored when ``flows`` is given.

    Returns:
        list[ndarray] | ndarray: Augmented images and flows. If returned
            results only have one element, just return ndarray.

    """
    # Each flag consumes one random draw only when its option is enabled.
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5

    def _augment(img):
        if hflip:  # horizontal
            cv2.flip(img, 1, img)
        if vflip:  # vertical
            cv2.flip(img, 0, img)
        if rot90:
            # Swap only the two spatial axes so both (h, w) and (h, w, c) work;
            # for 3-D input this is identical to img.transpose(1, 0, 2).
            img = img.swapaxes(0, 1)
        return img

    def _augment_flow(flow):
        if hflip:  # horizontal
            cv2.flip(flow, 1, flow)
            flow[:, :, 0] *= -1
        if vflip:  # vertical
            cv2.flip(flow, 0, flow)
            flow[:, :, 1] *= -1
        if rot90:
            # Transposing swaps x/y, so the (dx, dy) components swap as well.
            flow = flow.transpose(1, 0, 2)
            flow = flow[:, :, [1, 0]]
        return flow

    if not isinstance(imgs, list):
        imgs = [imgs]
    imgs = [_augment(img) for img in imgs]
    if len(imgs) == 1:
        imgs = imgs[0]

    if flows is not None:
        if not isinstance(flows, list):
            flows = [flows]
        flows = [_augment_flow(flow) for flow in flows]
        if len(flows) == 1:
            flows = flows[0]
        return imgs, flows
    else:
        if return_status:
            return imgs, (hflip, vflip, rot90)
        else:
            return imgs


def img_rotate(img, angle, center=None, scale=1.0):
    """Rotate (and optionally scale) an image about a pivot point.

    Args:
        img (ndarray): Image to be rotated.
        angle (float): Rotation angle in degrees. Positive values mean
            counter-clockwise rotation.
        center (tuple[int]): Rotation center. When None, ``(w // 2, h // 2)``
            is used. Default: None.
        scale (float): Isotropic scale factor. Default: 1.0.

    Returns:
        ndarray: Rotated image with the same width and height as ``img``.
    """
    height, width = img.shape[:2]
    pivot = (width // 2, height // 2) if center is None else center
    rot_mat = cv2.getRotationMatrix2D(pivot, angle, scale)
    return cv2.warpAffine(img, rot_mat, (width, height))


================================================
FILE: basicsr/losses/__init__.py
================================================
from copy import deepcopy

from basicsr.utils import get_root_logger
from basicsr.utils.registry import LOSS_REGISTRY
from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
                     gradient_penalty_loss, r1_penalty)

# Public names re-exported by ``from basicsr.losses import *``.
__all__ = [
    'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
    'r1_penalty', 'g_path_regularize'
]


def build_loss(opt):
    """Instantiate a registered loss from an option dict.

    Args:
        opt (dict): Configuration. It must contain:
            type (str): Name of the loss class in ``LOSS_REGISTRY``.
            All other keys are passed to the loss constructor as keyword
            arguments. ``opt`` itself is not modified.

    Returns:
        nn.Module: The constructed loss.
    """
    options = deepcopy(opt)
    loss_cls = LOSS_REGISTRY.get(options.pop('type'))
    loss = loss_cls(**options)
    get_root_logger().info(f'Loss [{loss.__class__.__name__}] is created.')
    return loss


================================================
FILE: basicsr/losses/loss_util.py
================================================
import functools
from torch.nn import functional as F


def reduce_loss(loss, reduction):
    """Reduce an element-wise loss tensor.

    Args:
        loss (Tensor): Element-wise loss tensor.
        reduction (str): 'none' returns ``loss`` unchanged, 'mean' averages
            it, 'sum' sums it. Validated by ``F._Reduction.get_enum``.

    Returns:
        Tensor: Reduced loss tensor.
    """
    # get_enum maps 'none' -> 0, 'mean' -> 1, 'sum' -> 2.
    mode = F._Reduction.get_enum(reduction)
    if mode == 0:
        return loss
    if mode == 1:
        return loss.mean()
    return loss.sum()


def weight_reduce_loss(loss, weight=None, reduction='mean'):
    """Apply optional element-wise weights to a loss, then reduce it.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights with the same number of dims as
            ``loss`` and a channel (dim 1) size of 1 or ``loss.size(1)``.
            Default: None.
        reduction (str): Same as built-in losses of PyTorch. Options are
            'none', 'mean' and 'sum'. Default: 'mean'.

    Returns:
        Tensor: Loss values. With weights and 'mean', the weighted sum is
            divided by the total weight instead of by the element count.
    """
    if weight is not None:
        assert weight.dim() == loss.dim()
        assert weight.size(1) in (1, loss.size(1))
        loss = loss * weight

    if weight is None or reduction == 'sum':
        return reduce_loss(loss, reduction)
    if reduction == 'mean':
        # A single-channel weight is broadcast over all channels of ``loss``,
        # so it has to be counted once per channel in the normalizer.
        denom = weight.sum() if weight.size(1) > 1 else weight.sum() * loss.size(1)
        return loss.sum() / denom
    return loss


def weighted_loss(loss_func):
    """Decorator that adds ``weight`` and ``reduction`` handling to a loss.

    The decorated ``loss_func(pred, target, **kwargs)`` must return an
    element-wise (unreduced) loss. The returned function has the signature
    ``loss_func(pred, target, weight=None, reduction='mean', **kwargs)`` and
    passes the element-wise loss through ``weight_reduce_loss``.

    :Example:

    >>> import torch
    >>> @weighted_loss
    >>> def l1_loss(pred, target):
    >>>     return (pred - target).abs()

    >>> pred = torch.Tensor([0, 2, 3]).view(1, 1, 1, 3)
    >>> target = torch.Tensor([1, 1, 1]).view(1, 1, 1, 3)
    >>> weight = torch.Tensor([1, 0, 1]).view(1, 1, 1, 3)

    >>> l1_loss(pred, target)
    tensor(1.3333)
    >>> l1_loss(pred, target, weight)
    tensor(1.5000)
    >>> l1_loss(pred, target, weight, reduction='sum')
    tensor(3.)
    """

    @functools.wraps(loss_func)
    def _weighted(pred, target, weight=None, reduction='mean', **kwargs):
        elementwise = loss_func(pred, target, **kwargs)
        return weight_reduce_loss(elementwise, weight, reduction)

    return _weighted


================================================
FILE: basicsr/losses/losses.py
================================================
import math
import lpips
import torch
from torch import autograd as autograd
from torch import nn as nn
from torch.nn import functional as F

from basicsr.archs.vgg_arch import VGGFeatureExtractor
from basicsr.utils.registry import LOSS_REGISTRY
from .loss_util import weighted_loss

# Reduction names accepted by the loss modules below (also shown in their error messages).
_reduction_modes = ['none', 'mean', 'sum']


@weighted_loss
def l1_loss(pred, target):
    """Element-wise absolute error; weighting/reduction come from ``weighted_loss``."""
    elementwise = F.l1_loss(pred, target, reduction='none')
    return elementwise


@weighted_loss
def mse_loss(pred, target):
    """Element-wise squared error; weighting/reduction come from ``weighted_loss``."""
    elementwise = F.mse_loss(pred, target, reduction='none')
    return elementwise


@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
    """Element-wise Charbonnier penalty ``sqrt((pred - target)^2 + eps)``."""
    diff = pred - target
    return torch.sqrt(diff**2 + eps)


@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
    """L1 (mean absolute error, MAE) loss.

    Args:
        loss_weight (float): Scalar multiplier applied to the reduced loss.
            Default: 1.0.
        reduction (str): How element-wise losses are reduced:
            'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(L1Loss, self).__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        loss = l1_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * loss


@LOSS_REGISTRY.register()
class MSELoss(nn.Module):
    """MSE (L2) loss.

    Args:
        loss_weight (float): Scalar multiplier applied to the reduced loss.
            Default: 1.0.
        reduction (str): How element-wise losses are reduced:
            'none' | 'mean' | 'sum'. Default: 'mean'.
    """

    def __init__(self, loss_weight=1.0, reduction='mean'):
        super(MSELoss, self).__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        loss = mse_loss(pred, target, weight, reduction=self.reduction)
        return self.loss_weight * loss


@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
    """Charbonnier loss: a differentiable, robust variant of L1Loss.

    Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
    Super-Resolution".

    Args:
        loss_weight (float): Scalar multiplier applied to the reduced loss.
            Default: 1.0.
        reduction (str): How element-wise losses are reduced:
            'none' | 'mean' | 'sum'. Default: 'mean'.
        eps (float): Added under the square root; controls the curvature
            near zero. Default: 1e-12.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
        super(CharbonnierLoss, self).__init__()
        if reduction not in _reduction_modes:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
        self.loss_weight = loss_weight
        self.reduction = reduction
        self.eps = eps

    def forward(self, pred, target, weight=None, **kwargs):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.
            weight (Tensor, optional): of shape (N, C, H, W). Element-wise
                weights. Default: None.
        """
        loss = charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
        return self.loss_weight * loss


@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
    """Weighted total-variation (TV) loss.

    Sums the L1 differences between vertically and horizontally adjacent
    pixels of ``pred``. An optional ``weight`` tensor, aligned with ``pred``,
    masks each difference.

        Args:
            loss_weight (float): Loss weight. Default: 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)

    def forward(self, pred, weight=None):
        """
        Args:
            pred (Tensor): of shape (N, C, H, W).
            weight (Tensor, optional): Element-wise weights with the same
                spatial size as ``pred`` and 1 or C channels. When None, all
                differences are weighted equally. Default: None.

        Returns:
            Tensor: Sum of the horizontal and vertical TV terms.
        """
        # Slicing only when weights are given; the previous code indexed None.
        if weight is None:
            y_weight = x_weight = None
        else:
            y_weight = weight[:, :, :-1, :]
            x_weight = weight[:, :, :, :-1]

        y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
        x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)

        loss = x_diff + y_diff

        return loss


@LOSS_REGISTRY.register()
class PerceptualLoss(nn.Module):
    """Perceptual loss with commonly used style loss.

    Args:
        layer_weights (dict): The weight for each layer of vgg feature.
            Here is an example: {'conv5_4': 1.}, which means the conv5_4
            feature layer (before relu5_4) will be extracted with weight
            1.0 in calculating losses.
        vgg_type (str): The type of vgg network used as feature extractor.
            Default: 'vgg19'.
        use_input_norm (bool):  If True, normalize the input image in vgg.
            Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
            loss will be calculated and the loss will be multiplied by the
            weight. Default: 1.0.
        style_weight (float): If `style_weight > 0`, the style loss will be
            calculated and the loss will be multiplied by the weight.
            Default: 0.
        criterion (str): Criterion used for perceptual loss: 'l1', 'l2'
            (same as 'mse'), 'mse', or 'fro' (Frobenius norm of the
            difference). Default: 'l1'.
    """

    def __init__(self,
                 layer_weights,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 perceptual_weight=1.0,
                 style_weight=0.,
                 criterion='l1'):
        super(PerceptualLoss, self).__init__()
        self.perceptual_weight = perceptual_weight
        self.style_weight = style_weight
        self.layer_weights = layer_weights
        self.vgg = VGGFeatureExtractor(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            use_input_norm=use_input_norm,
            range_norm=range_norm)

        self.criterion_type = criterion
        if self.criterion_type == 'l1':
            self.criterion = torch.nn.L1Loss()
        elif self.criterion_type in ('l2', 'mse'):
            # torch.nn has no ``L2loss``; an L2 criterion here is the mean squared error.
            self.criterion = torch.nn.MSELoss(reduction='mean')
        elif self.criterion_type == 'fro':
            # Computed directly with torch.norm in forward().
            self.criterion = None
        else:
            raise NotImplementedError(f'{criterion} criterion has not been supported.')

    def forward(self, x, gt):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).
            gt (Tensor): Ground-truth tensor with shape (n, c, h, w).

        Returns:
            tuple: (percep_loss, style_loss). Each entry is None when the
                corresponding weight is not positive.
        """
        # extract vgg features; no gradient flows into the ground truth
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        # calculate perceptual loss
        if self.perceptual_weight > 0:
            percep_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
                else:
                    percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
            percep_loss *= self.perceptual_weight
        else:
            percep_loss = None

        # calculate style loss
        if self.style_weight > 0:
            style_loss = 0
            for k in x_features.keys():
                if self.criterion_type == 'fro':
                    style_loss += torch.norm(
                        self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
                else:
                    style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
                        gt_features[k])) * self.layer_weights[k]
            style_loss *= self.style_weight
        else:
            style_loss = None

        return percep_loss, style_loss

    def _gram_mat(self, x):
        """Calculate Gram matrix.

        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix of shape (n, c, c), normalized by c*h*w.
        """
        n, c, h, w = x.size()
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram


@LOSS_REGISTRY.register()
class LPIPSLoss(nn.Module):
    """LPIPS distance (``lpips.LPIPS`` with the VGG backbone) used as a loss.

    Args:
        loss_weight (float): Multiplier applied to the batch-averaged LPIPS
            distance. Default: 1.0.
        use_input_norm (bool): If True, standardize both inputs with the
            ImageNet mean/std before calling LPIPS. Default: True.
        range_norm (bool): If True, first map inputs via ``(x + 1) / 2``
            (i.e. [-1, 1] to [0, 1]). Default: False.

    NOTE(review): lpips.LPIPS documents inputs in [-1, 1]; confirm that
    feeding ImageNet-standardized tensors is intended.
    """

    def __init__(self, loss_weight=1.0, use_input_norm=True, range_norm=False):
        super(LPIPSLoss, self).__init__()
        self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
        self.loss_weight = loss_weight
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        if self.use_input_norm:
            # ImageNet statistics for images in range [0, 1]
            imagenet_mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
            imagenet_std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
            self.register_buffer('mean', imagenet_mean)
            self.register_buffer('std', imagenet_std)

    def forward(self, pred, target):
        """Return ``loss_weight`` times the mean LPIPS distance between the inputs."""
        if self.range_norm:
            pred, target = (pred + 1) / 2, (target + 1) / 2
        if self.use_input_norm:
            pred, target = (pred - self.mean) / self.std, (target - self.mean) / self.std
        distance = self.perceptual(target.contiguous(), pred.contiguous())
        return self.loss_weight * distance.mean()


@LOSS_REGISTRY.register()
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'wgan_softplus',
            'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always 1.0
            for discriminators.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # Only the requested criterion is constructed.
        builders = {
            'vanilla': nn.BCEWithLogitsLoss,
            'lsgan': nn.MSELoss,
            'wgan': lambda: self._wgan_loss,
            'wgan_softplus': lambda: self._wgan_softplus_loss,
            'hinge': nn.ReLU,
        }
        if self.gan_type not in builders:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
        self.loss = builders[self.gan_type]()

    def _wgan_loss(self, input, target):
        """wgan loss: negated mean for real targets, plain mean for fake ones.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        mean_score = input.mean()
        return -mean_score if target else mean_score

    def _wgan_softplus_loss(self, input, target):
        """wgan loss with soft plus (a smooth approximation to ReLU).

        In StyleGAN2, it is called:
            Logistic loss for discriminator;
            Non-saturating loss for generator.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        signed = -input if target else input
        return F.softplus(signed).mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): The bool itself for wgan variants; otherwise a
                tensor shaped like ``input`` filled with the label value.
        """
        if self.gan_type in ['wgan', 'wgan_softplus']:
            return target_is_real
        fill_value = self.real_label_val if target_is_real else self.fake_label_val
        return input.new_ones(input.size()) * fill_value

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        if self.gan_type != 'hinge':
            loss = self.loss(input, self.get_target_label(input, target_is_real))
        elif not is_disc:
            # hinge loss for generators
            loss = -input.mean()
        else:
            # hinge loss for discriminators
            signed = -input if target_is_real else input
            loss = self.loss(1 + signed).mean()

        # loss_weight is always 1.0 for discriminators
        if is_disc:
            return loss
        return loss * self.loss_weight


def r1_penalty(real_pred, real_img):
    """R1 regularization for the discriminator.

    Penalizes the squared gradient norm of the discriminator output with
    respect to real images only. The penalty is summed per sample and
    averaged over the batch. The graph is kept (``create_graph=True``) so the
    penalty itself can be backpropagated.

    Ref:
    Eq. 9 in Which training methods for GANs do actually converge.
    """
    (grad_real,) = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)
    per_sample = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1)
    return per_sample.mean()


def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
    """Path-length regularization term for the generator.

    Args:
        fake_img (Tensor): Generated images; dims 2 and 3 are used as the
            spatial size for noise scaling.
        latents (Tensor): Latents ``fake_img`` was generated from; the squared
            gradient is summed over dim 2 and averaged over dim 1.
        mean_path_length (float | Tensor): Running mean of path lengths.
        decay (float): Update rate of the running mean. Default: 0.01.

    Returns:
        tuple: (path_penalty, detached mean path length, detached updated
            running mean).
    """
    spatial_size = fake_img.shape[2] * fake_img.shape[3]
    noise = torch.randn_like(fake_img) / math.sqrt(spatial_size)
    (grad,) = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)
    path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

    # exponential moving average of the path length
    path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
    path_penalty = (path_lengths - path_mean).pow(2).mean()

    return path_penalty, path_lengths.detach().mean(), path_mean.detach()


def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
    """Calculate gradient penalty for wgan-gp.

    Evaluates the discriminator on random per-sample interpolations between
    real and fake data. It then penalizes the deviation of the gradient norm
    (taken over dim 1) from 1.

    Args:
        discriminator (nn.Module): Network for the discriminator.
        real_data (Tensor): Real input data.
        fake_data (Tensor): Fake input data.
        weight (Tensor): Weight tensor multiplied into the gradients; the
            penalty is then divided by its mean. Default: None.

    Returns:
        Tensor: A tensor for gradient penalty.
    """

    batch_size = real_data.size(0)
    # Draw on CPU as before, then match real_data's dtype/device. This avoids
    # Tensor.new_tensor(tensor), which warns about copy-constructing a tensor.
    alpha = torch.rand(batch_size, 1, 1, 1).to(real_data)

    # interpolate between real_data and fake_data, as a fresh leaf that needs
    # grad (the same semantics as the legacy autograd.Variable(x, requires_grad=True))
    interpolates = alpha * real_data + (1. - alpha) * fake_data
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = discriminator(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0]

    if weight is not None:
        gradients = gradients * weight

    gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
    if weight is not None:
        gradients_penalty /= torch.mean(weight)

    return gradients_penalty


================================================
FILE: basicsr/metrics/__init__.py
================================================
from copy import deepcopy

from basicsr.utils.registry import METRIC_REGISTRY
from .psnr_ssim import calculate_psnr, calculate_ssim

# Public names re-exported by ``from basicsr.metrics import *``.
__all__ = ['calculate_psnr', 'calculate_ssim']


def calculate_metric(data, opt):
    """Calculate a registered metric from data and options.

    Args:
        data (dict): Keyword arguments forwarded to the metric function
            (e.g. the images to compare).
        opt (dict): Configuration. It must contain:
            type (str): Name of the metric in ``METRIC_REGISTRY``.
            All other keys are forwarded as keyword arguments. ``opt``
            itself is not modified.

    Returns:
        Whatever the metric function returns.
    """
    options = deepcopy(opt)
    metric_fn = METRIC_REGISTRY.get(options.pop('type'))
    return metric_fn(**data, **options)


================================================
FILE: basicsr/metrics/metric_util.py
================================================
import numpy as np

from basicsr.utils.matlab_functions import bgr2ycbcr


def reorder_image(img, input_order='HWC'):
    """Reorder images to 'HWC' order.

    If the input shape is (h, w), return (h, w, 1) regardless of input_order;
    If the input_order is (c, h, w), return (h, w, c);
    If the input_order is (h, w, c), return as it is.

    Args:
        img (ndarray): Input image.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            If the input image shape is (h, w), input_order will not have
            effects. Default: 'HWC'.

    Returns:
        ndarray: reordered image.

    Raises:
        ValueError: If ``input_order`` is not 'HWC' or 'CHW'.
    """

    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
    if len(img.shape) == 2:
        # A 2-D image has no channel axis to move. Returning early keeps the
        # (h, w, 1) result; the CHW transpose would otherwise give (w, 1, h).
        return img[..., None]
    if input_order == 'CHW':
        img = img.transpose(1, 2, 0)
    return img


def to_y_channel(img):
    """Convert a 3-channel (BGR) image to the Y channel of YCbCr.

    Images whose last dim is not 3 are returned as float32 with values
    unchanged apart from the /255 * 255 round trip.

    Args:
        img (ndarray): Images with range [0, 255].

    Returns:
        (ndarray): Images with range [0, 255] (float type) without round.
            A 3-channel input yields shape (h, w, 1).
    """
    scaled = img.astype(np.float32) / 255.
    is_three_channel = scaled.ndim == 3 and scaled.shape[2] == 3
    if is_three_channel:
        scaled = bgr2ycbcr(scaled, y_only=True)[..., None]
    return scaled * 255.


================================================
FILE: basicsr/metrics/psnr_ssim.py
================================================
import cv2
import numpy as np

from basicsr.metrics.metric_util import reorder_image, to_y_channel
from basicsr.utils.registry import METRIC_REGISTRY


@METRIC_REGISTRY.register()
def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """Calculate PSNR (Peak Signal-to-Noise Ratio).

    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the PSNR calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: psnr result (``inf`` for identical images).
    """

    assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')

    def _prepare(img):
        # HWC float64, border-cropped, optionally reduced to the Y channel.
        img = reorder_image(img, input_order=input_order).astype(np.float64)
        if crop_border != 0:
            img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
        if test_y_channel:
            img = to_y_channel(img)
        return img

    img1, img2 = _prepare(img1), _prepare(img2)

    mse = np.mean((img1 - img2)**2)
    if mse == 0:
        return float('inf')
    return 20. * np.log10(255. / np.sqrt(mse))


def _ssim(img1, img2):
    """Compute the mean SSIM of two single-channel images.

    Helper for :func:`calculate_ssim`, which passes one channel slice at a
    time. Local statistics come from an 11x11 Gaussian window (sigma 1.5).
    A 5-pixel band is trimmed from every side of each filtered map, because
    those pixels depend on border padding for an 11-tap kernel.

    Args:
        img1 (ndarray): Single-channel image with range [0, 255].
        img2 (ndarray): Single-channel image with range [0, 255], same shape.

    Returns:
        float: Mean value of the SSIM map.
    """
    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    x = img1.astype(np.float64)
    y = img2.astype(np.float64)
    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss_1d, gauss_1d.transpose())

    def _blur_valid(arr):
        # Gaussian-filter, then keep only pixels whose window is fully inside the image.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu_x = _blur_valid(x)
    mu_y = _blur_valid(y)
    mu_x_sq = mu_x**2
    mu_y_sq = mu_y**2
    mu_xy = mu_x * mu_y
    var_x = _blur_valid(x**2) - mu_x_sq
    var_y = _blur_valid(y**2) - mu_y_sq
    cov_xy = _blur_valid(x * y) - mu_xy

    numerator = (2 * mu_xy + c1) * (2 * cov_xy + c2)
    denominator = (mu_x_sq + mu_y_sq + c1) * (var_x + var_y + c2)
    ssim_map = numerator / denominator
    return ssim_map.mean()


@METRIC_REGISTRY.register()
def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
    """Calculate SSIM (structural similarity).

    Ref:
    Image quality assessment: From error visibility to structural similarity

    The results are the same as that of the official released MATLAB code in
    https://ece.uwaterloo.ca/~z70wang/research/ssim/.

    For three-channel images, SSIM is calculated for each channel and then
    averaged.

    Args:
        img1 (ndarray): Images with range [0, 255].
        img2 (ndarray): Images with range [0, 255].
        crop_border (int): Cropped pixels in each edge of an image. These
            pixels are not involved in the SSIM calculation.
        input_order (str): Whether the input order is 'HWC' or 'CHW'.
            Default: 'HWC'.
        test_y_channel (bool): Test on Y channel of YCbCr. Default: False.

    Returns:
        float: ssim result (mean over channels).

    Raises:
        AssertionError: If the two images have different shapes.
        ValueError: If ``input_order`` is neither 'HWC' nor 'CHW'.
    """

    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
    if input_order not in ['HWC', 'CHW']:
        raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
    img1 = reorder_image(img1, input_order=input_order)
    img2 = reorder_image(img2, input_order=input_order)
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    if crop_border != 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    if test_y_channel:
        img1 = to_y_channel(img1)
        img2 = to_y_channel(img2)

    # SSIM is evaluated independently on each channel (last axis), then averaged.
    ssims = [_ssim(img1[..., i], img2[..., i]) for i in range(img1.shape[2])]
    return np.array(ssims).mean()


================================================
FILE: basicsr/models/__init__.py
================================================
import importlib
from copy import deepcopy
from os import path as osp

from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import MODEL_REGISTRY

__all__ = ['build_model']

# Automatically discover and import the model modules for the registry:
# every file in this 'models' folder whose name ends with '_model.py' is
# imported as 'basicsr.models.<module name>'.
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [
    osp.splitext(osp.basename(path))[0]
    for path in scandir(model_folder)
    if path.endswith('_model.py')
]
_model_modules = list(map(lambda name: importlib.import_module(f'basicsr.models.{name}'), model_filenames))


def build_model(opt):
    """Build the model class registered under ``opt['model_type']``.

    The options are deep-copied first, so the model never shares mutable
    state with the caller's dict.

    Args:
        opt (dict): Configuration. It must contain:
            model_type (str): Model type.

    Returns:
        The model instance created from the deep-copied options.
    """
    cfg = deepcopy(opt)
    model_cls = MODEL_REGISTRY.get(cfg['model_type'])
    model = model_cls(cfg)
    get_root_logger().info(f'Model [{model.__class__.__name__}] is created.')
    return model


================================================
FILE: basicsr/models/base_model.py
================================================
import logging
import os
import torch
from collections import OrderedDict
from copy import deepcopy
from torch.nn.parallel import DataParallel, DistributedDataParallel

from basicsr.models import lr_scheduler as lr_scheduler
from basicsr.utils.dist_util import master_only

logger = logging.getLogger('basicsr')


class BaseModel():
    """Base model."""

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.is_train = opt['is_train']
        self.schedulers = []
        self.optimizers = []

    def feed_data(self, data):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        pass

    def save(self, epoch, current_iter):
        """Save networks and training state."""
        pass

    def validation(self, dataloader, current_iter, tb_logger, save_img=False):
        """Validation function.

        Args:
            dataloader (torch.utils.data.DataLoader): Validation dataloader.
            current_iter (int): Current iteration.
            tb_logger (tensorboard logger): Tensorboard logger.
            save_img (bool): Whether to save images. Default: False.
        """
        if self.opt['dist']:
            self.dist_validation(dataloader, current_iter, tb_logger, save_img)
        else:
            self.nondist_validation(dataloader, current_iter, tb_logger, save_img)

    def model_ema(self, decay=0.999):
        net_g = self.get_bare_model(self.net_g)

        net_g_params = dict(net_g.named_parameters())
        net_g_ema_params = dict(self.net_g_ema.named_parameters())

        for k in net_g_ema_params.keys():
            net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)

    def get_current_log(self):
        return self.log_dict

    def model_to_device(self, net):
        """Model to device. It also warps models with DistributedDataParallel
        or DataParallel.

        Args:
            net (nn.Module)
        """
        net = net.to(self.device)
        if self.opt['dist']:
            find_unused_parameters = self.opt.get('find_unused_parameters', False)
            net = DistributedDataParallel(
                net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
        elif self.opt['num_gpu'] > 1:
            net = DataParallel(net)
        return net

    def get_optimizer(self, optim_type, params, lr, **kwargs):
        if optim_type == 'Adam':
            optimizer = torch.optim.Adam(params, lr, **kwargs)
        else:
            raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.')
        return optimizer

    def setup_schedulers(self):
        """Set up schedulers."""
        train_opt = self.opt['train']
        scheduler_type = train_opt['scheduler'].pop('type')
        if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
        elif scheduler_type == 'CosineAnnealingRestartLR':
            for optimizer in self.optimizers:
                self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
        else:
            raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')

    def get_bare_model(self, net):
        """Get bare model, especially under wrapping with
        DistributedDataParallel or DataParallel.
        """
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net = net.module
        return net

    @master_only
    def print_network(self, net):
        """Log a network's class name(s), its parameter count and its repr.

        For a (Distributed)DataParallel wrapper, both the wrapper class and
        the wrapped class are named.

        Args:
            net (nn.Module)
        """
        bare_net = self.get_bare_model(net)
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net_cls_str = f'{net.__class__.__name__} - {bare_net.__class__.__name__}'
        else:
            net_cls_str = f'{net.__class__.__name__}'

        net_params = sum(p.numel() for p in bare_net.parameters())
        logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
        logger.info(str(bare_net))

    def _set_lr(self, lr_groups_l):
        """Set learning rate for warmup.

        Args:
            lr_groups_l (list): List for lr_groups, each for an optimizer.
        """
        for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
            for param_group, lr in zip(optimizer.param_groups, lr_groups):
                param_group['lr'] = lr

    def _get_init_lr(self):
        """Get the initial lr, which is set by the scheduler.
        """
        init_lr_groups_l = []
        for optimizer in self.optimizers:
            init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
        return init_lr_groups_l

    def update_learning_rate(self, current_iter, warmup_iter=-1):
        """Update learning rate.

        Args:
            current_iter (int): Current iteration.
            warmup_iter (int): Warmup iter numbers. -1 for no warmup.
                Default: -1.
        """
        if current_iter > 1:
            for scheduler in self.schedulers:
                scheduler.step()
        # set up warm-up learning rate
        if current_iter < warmup_iter:
            # get initial lr for each group
            init_lr_g_l = self._get_init_lr()
            # modify warming-up learning rates
            # currently only support linearly warm up
            warm_up_lr_l = []
            for init_lr_g in init_lr_g_l:
                warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
            # set learning rate
            self._set_lr(warm_up_lr_l)

    def get_current_learning_rate(self):
        return [param_group['lr'] for param_group in self.optimizers[0].param_groups]

    @master_only
    def save_network(self, net, net_label, current_iter, param_key='params'):
        """Save networks.

        Args:
            net (nn.Module | list[nn.Module]): Network(s) to be saved.
            net_label (str): Network label.
            current_iter (int): Current iter number.
            param_key (str | list[str]): The parameter key(s) to save network.
                Default: 'params'.
        """
        if current_iter == -1:
            current_iter = 'latest'
        save_filename = f'{net_label}_{current_iter}.pth'
        save_path = os.path.join(self.opt['path']['models'], save_filename)

        net = net if isinstance(net, list) else [net]
        param_key = param_key if isinstance(param_key, list) else [param_key]
        assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'

        save_dict = {}
        for net_, param_key_ in zip(net, param_key):
            net_ = self.get_bare_model(net_)
            state_dict = net_.state_dict()
            for key, param in state_dict.items():
                if key.startswith('module.'):  # remove unnecessary 'module.'
                    key = key[7:]
                state_dict[key] = param.cpu()
            save_dict[param_key_] = state_dict

        torch.save(save_dict, save_path)

    def _print_different_keys_loading(self, crt_net, load_net, strict=True):
        """Print keys with differnet name or different size when loading models.

        1. Print keys with differnet names.
        2. If strict=False, print the same key but with different tensor size.
       
Download .txt
gitextract_zvvj76bx/

├── .gitignore
├── LICENSE
├── README.md
├── basicsr/
│   ├── VERSION
│   ├── __init__.py
│   ├── archs/
│   │   ├── __init__.py
│   │   ├── arcface_arch.py
│   │   ├── arch_util.py
│   │   ├── codeformer_arch.py
│   │   ├── rrdbnet_arch.py
│   │   ├── vgg_arch.py
│   │   └── vqgan_arch.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── data_sampler.py
│   │   ├── data_util.py
│   │   ├── ffhq_blind_dataset.py
│   │   ├── ffhq_blind_joint_dataset.py
│   │   ├── gaussian_kernels.py
│   │   ├── paired_image_dataset.py
│   │   ├── prefetch_dataloader.py
│   │   └── transforms.py
│   ├── losses/
│   │   ├── __init__.py
│   │   ├── loss_util.py
│   │   └── losses.py
│   ├── metrics/
│   │   ├── __init__.py
│   │   ├── metric_util.py
│   │   └── psnr_ssim.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── codeformer_idx_model.py
│   │   ├── codeformer_joint_model.py
│   │   ├── codeformer_model.py
│   │   ├── lr_scheduler.py
│   │   ├── sr_model.py
│   │   └── vqgan_model.py
│   ├── ops/
│   │   ├── __init__.py
│   │   ├── dcn/
│   │   │   ├── __init__.py
│   │   │   ├── deform_conv.py
│   │   │   └── src/
│   │   │       ├── deform_conv_cuda.cpp
│   │   │       ├── deform_conv_cuda_kernel.cu
│   │   │       └── deform_conv_ext.cpp
│   │   ├── fused_act/
│   │   │   ├── __init__.py
│   │   │   ├── fused_act.py
│   │   │   └── src/
│   │   │       ├── fused_bias_act.cpp
│   │   │       └── fused_bias_act_kernel.cu
│   │   └── upfirdn2d/
│   │       ├── __init__.py
│   │       ├── src/
│   │       │   ├── upfirdn2d.cpp
│   │       │   └── upfirdn2d_kernel.cu
│   │       └── upfirdn2d.py
│   ├── setup.py
│   ├── train.py
│   └── utils/
│       ├── __init__.py
│       ├── dist_util.py
│       ├── download_util.py
│       ├── file_client.py
│       ├── img_util.py
│       ├── lmdb_util.py
│       ├── logger.py
│       ├── matlab_functions.py
│       ├── misc.py
│       ├── options.py
│       ├── realesrgan_utils.py
│       ├── registry.py
│       └── video_util.py
├── docs/
│   ├── history_changelog.md
│   ├── train.md
│   └── train_CN.md
├── facelib/
│   ├── detection/
│   │   ├── __init__.py
│   │   ├── align_trans.py
│   │   ├── matlab_cp2tform.py
│   │   ├── retinaface/
│   │   │   ├── retinaface.py
│   │   │   ├── retinaface_net.py
│   │   │   └── retinaface_utils.py
│   │   └── yolov5face/
│   │       ├── __init__.py
│   │       ├── face_detector.py
│   │       ├── models/
│   │       │   ├── __init__.py
│   │       │   ├── common.py
│   │       │   ├── experimental.py
│   │       │   ├── yolo.py
│   │       │   ├── yolov5l.yaml
│   │       │   └── yolov5n.yaml
│   │       └── utils/
│   │           ├── __init__.py
│   │           ├── autoanchor.py
│   │           ├── datasets.py
│   │           ├── extract_ckpt.py
│   │           ├── general.py
│   │           └── torch_utils.py
│   ├── parsing/
│   │   ├── __init__.py
│   │   ├── bisenet.py
│   │   ├── parsenet.py
│   │   └── resnet.py
│   └── utils/
│       ├── __init__.py
│       ├── face_restoration_helper.py
│       ├── face_utils.py
│       └── misc.py
├── inference_codeformer.py
├── inference_colorization.py
├── inference_inpainting.py
├── options/
│   ├── CodeFormer_colorization.yml
│   ├── CodeFormer_inpainting.yml
│   ├── CodeFormer_stage2.yml
│   ├── CodeFormer_stage3.yml
│   └── VQGAN_512_ds32_nearest_stage1.yml
├── requirements.txt
├── scripts/
│   ├── crop_align_face.py
│   ├── download_pretrained_models.py
│   ├── download_pretrained_models_from_gdrive.py
│   ├── generate_latent_gt.py
│   └── inference_vqgan.py
├── web-demos/
│   ├── hugging_face/
│   │   └── app.py
│   └── replicate/
│       ├── cog.yaml
│       └── predict.py
└── weights/
    ├── CodeFormer/
    │   └── .gitkeep
    ├── README.md
    └── facelib/
        └── .gitkeep
Download .txt
SYMBOL INDEX (722 symbols across 78 files)

FILE: basicsr/archs/__init__.py
  function build_network (line 19) | def build_network(opt):

FILE: basicsr/archs/arcface_arch.py
  function conv3x3 (line 5) | def conv3x3(inplanes, outplanes, stride=1):
  class BasicBlock (line 16) | class BasicBlock(nn.Module):
    method __init__ (line 27) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 37) | def forward(self, x):
  class IRBlock (line 56) | class IRBlock(nn.Module):
    method __init__ (line 68) | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se...
    method forward (line 82) | def forward(self, x):
  class Bottleneck (line 103) | class Bottleneck(nn.Module):
    method __init__ (line 114) | def __init__(self, inplanes, planes, stride=1, downsample=None):
    method forward (line 126) | def forward(self, x):
  class SEBlock (line 149) | class SEBlock(nn.Module):
    method __init__ (line 157) | def __init__(self, channel, reduction=16):
    method forward (line 164) | def forward(self, x):
  class ResNetArcFace (line 172) | class ResNetArcFace(nn.Module):
    method __init__ (line 183) | def __init__(self, block, layers, use_se=True):
    method _make_layer (line 214) | def _make_layer(self, block, planes, num_blocks, stride=1):
    method forward (line 229) | def forward(self, x):

FILE: basicsr/archs/arch_util.py
  function default_init_weights (line 18) | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
  function make_layer (line 48) | def make_layer(basic_block, num_basic_block, **kwarg):
  class ResidualBlockNoBN (line 64) | class ResidualBlockNoBN(nn.Module):
    method __init__ (line 79) | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
    method forward (line 89) | def forward(self, x):
  class Upsample (line 95) | class Upsample(nn.Sequential):
    method __init__ (line 103) | def __init__(self, scale, num_feat):
  function flow_warp (line 117) | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', ali...
  function resize_flow (line 151) | def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_co...
  function pixel_unshuffle (line 190) | def pixel_unshuffle(x, scale):
  class DCNv2Pack (line 209) | class DCNv2Pack(ModulatedDeformConvPack):
    method forward (line 220) | def forward(self, x, feat):
  function _no_grad_trunc_normal_ (line 239) | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
  function trunc_normal_ (line 277) | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
  function _ntuple (line 304) | def _ntuple(n):

FILE: basicsr/archs/codeformer_arch.py
  function calc_mean_std (line 12) | def calc_mean_std(feat, eps=1e-5):
  function adaptive_instance_normalization (line 29) | def adaptive_instance_normalization(content_feat, style_feat):
  class PositionEmbeddingSine (line 46) | class PositionEmbeddingSine(nn.Module):
    method __init__ (line 52) | def __init__(self, num_pos_feats=64, temperature=10000, normalize=Fals...
    method forward (line 63) | def forward(self, x, mask=None):
  function _get_activation_fn (line 88) | def _get_activation_fn(activation):
  class TransformerSALayer (line 99) | class TransformerSALayer(nn.Module):
    method __init__ (line 100) | def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, acti...
    method with_pos_embed (line 115) | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
    method forward (line 118) | def forward(self, tgt,
  class Fuse_sft_block (line 136) | class Fuse_sft_block(nn.Module):
    method __init__ (line 137) | def __init__(self, in_ch, out_ch):
    method forward (line 151) | def forward(self, enc_feat, dec_feat, w=1):
  class CodeFormer (line 161) | class CodeFormer(VQAutoEncoder):
    method __init__ (line 162) | def __init__(self, dim_embd=512, n_head=8, n_layers=9,
    method _init_weights (line 214) | def _init_weights(self, module):
    method forward (line 223) | def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):

FILE: basicsr/archs/rrdbnet_arch.py
  class ResidualDenseBlock (line 9) | class ResidualDenseBlock(nn.Module):
    method __init__ (line 19) | def __init__(self, num_feat=64, num_grow_ch=32):
    method forward (line 32) | def forward(self, x):
  class RRDB (line 42) | class RRDB(nn.Module):
    method __init__ (line 52) | def __init__(self, num_feat, num_grow_ch=32):
    method forward (line 58) | def forward(self, x):
  class RRDBNet (line 67) | class RRDBNet(nn.Module):
    method __init__ (line 87) | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_bl...
    method forward (line 105) | def forward(self, x):

FILE: basicsr/archs/vgg_arch.py
  function insert_bn (line 36) | def insert_bn(names):
  class VGGFeatureExtractor (line 55) | class VGGFeatureExtractor(nn.Module):
    method __init__ (line 78) | def __init__(self,
    method forward (line 141) | def forward(self, x):

FILE: basicsr/archs/vqgan_arch.py
  function normalize (line 14) | def normalize(in_channels):
  function swish (line 19) | def swish(x):
  class VectorQuantizer (line 24) | class VectorQuantizer(nn.Module):
    method __init__ (line 25) | def __init__(self, codebook_size, emb_dim, beta):
    method forward (line 33) | def forward(self, z):
    method get_codebook_feat (line 72) | def get_codebook_feat(self, indices, shape):
  class GumbelQuantizer (line 87) | class GumbelQuantizer(nn.Module):
    method __init__ (line 88) | def __init__(self, codebook_size, emb_dim, num_hiddens, straight_throu...
    method forward (line 98) | def forward(self, z):
  class Downsample (line 117) | class Downsample(nn.Module):
    method __init__ (line 118) | def __init__(self, in_channels):
    method forward (line 122) | def forward(self, x):
  class Upsample (line 129) | class Upsample(nn.Module):
    method __init__ (line 130) | def __init__(self, in_channels):
    method forward (line 134) | def forward(self, x):
  class ResBlock (line 141) | class ResBlock(nn.Module):
    method __init__ (line 142) | def __init__(self, in_channels, out_channels=None):
    method forward (line 153) | def forward(self, x_in):
  class AttnBlock (line 167) | class AttnBlock(nn.Module):
    method __init__ (line 168) | def __init__(self, in_channels):
    method forward (line 202) | def forward(self, x):
  class Encoder (line 229) | class Encoder(nn.Module):
    method __init__ (line 230) | def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, ...
    method forward (line 269) | def forward(self, x):
  class Generator (line 276) | class Generator(nn.Module):
    method __init__ (line 277) | def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_re...
    method forward (line 319) | def forward(self, x):
  class VQAutoEncoder (line 327) | class VQAutoEncoder(nn.Module):
    method __init__ (line 328) | def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blo...
    method forward (line 385) | def forward(self, x):
  class VQGANDiscriminator (line 395) | class VQGANDiscriminator(nn.Module):
    method __init__ (line 396) | def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
    method forward (line 433) | def forward(self, x):

FILE: basicsr/data/__init__.py
  function build_dataset (line 25) | def build_dataset(dataset_opt):
  function build_dataloader (line 40) | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sample...
  function worker_init_fn (line 96) | def worker_init_fn(worker_id, num_workers, rank, seed):

FILE: basicsr/data/data_sampler.py
  class EnlargedSampler (line 6) | class EnlargedSampler(Sampler):
    method __init__ (line 21) | def __init__(self, dataset, num_replicas, rank, ratio=1):
    method __iter__ (line 29) | def __iter__(self):
    method __len__ (line 44) | def __len__(self):
    method set_epoch (line 47) | def set_epoch(self, epoch):

FILE: basicsr/data/data_util.py
  function read_img_seq (line 13) | def read_img_seq(path, require_mod_crop=False, scale=1):
  function generate_frame_indices (line 37) | def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='...
  function paired_paths_from_lmdb (line 89) | def paired_paths_from_lmdb(folders, keys):
  function paired_paths_from_meta_info_file (line 148) | def paired_paths_from_meta_info_file(folders, keys, meta_info_file, file...
  function paired_paths_from_folder (line 192) | def paired_paths_from_folder(folders, keys, filename_tmpl):
  function paths_from_folder (line 228) | def paths_from_folder(folder):
  function paths_from_lmdb (line 243) | def paths_from_lmdb(folder):
  function generate_gaussian_kernel (line 259) | def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
  function duf_downsample (line 277) | def duf_downsample(x, kernel_size=13, scale=4):
  function brush_stroke_mask (line 310) | def brush_stroke_mask(img, color=(255,255,255)):
  function random_ff_mask (line 365) | def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70,...

FILE: basicsr/data/ffhq_blind_dataset.py
  class FFHQBlindDataset (line 19) | class FFHQBlindDataset(data.Dataset):
    method __init__ (line 21) | def __init__(self, opt):
    method color_jitter (line 117) | def color_jitter(img, shift):
    method color_jitter_pt (line 125) | def color_jitter_pt(img, brightness, contrast, saturation, hue):
    method get_component_locations (line 147) | def get_component_locations(self, name, status):
    method __getitem__ (line 179) | def __getitem__(self, index):
    method __len__ (line 298) | def __len__(self):

FILE: basicsr/data/ffhq_blind_joint_dataset.py
  class FFHQBlindJointDataset (line 18) | class FFHQBlindJointDataset(data.Dataset):
    method __init__ (line 20) | def __init__(self, opt):
    method color_jitter (line 109) | def color_jitter(img, shift):
    method color_jitter_pt (line 117) | def color_jitter_pt(img, brightness, contrast, saturation, hue):
    method get_component_locations (line 139) | def get_component_locations(self, name, status):
    method __getitem__ (line 171) | def __getitem__(self, index):
    method __len__ (line 323) | def __len__(self):

FILE: basicsr/data/gaussian_kernels.py
  function sigma_matrix2 (line 8) | def sigma_matrix2(sig_x, sig_y, theta):
  function mesh_grid (line 23) | def mesh_grid(kernel_size):
  function pdf2 (line 40) | def pdf2(sigma_matrix, grid):
  function cdf2 (line 54) | def cdf2(D, grid):
  function bivariate_skew_Gaussian (line 70) | def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=No...
  function mass_center_shift (line 96) | def mass_center_shift(kernel_size, kernel):
  function bivariate_skew_Gaussian_center (line 112) | def bivariate_skew_Gaussian_center(kernel_size,
  function bivariate_anisotropic_Gaussian (line 139) | def bivariate_anisotropic_Gaussian(kernel_size,
  function bivariate_isotropic_Gaussian (line 163) | def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None):
  function bivariate_generalized_Gaussian (line 181) | def bivariate_generalized_Gaussian(kernel_size,
  function bivariate_plateau_type1 (line 213) | def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid...
  function bivariate_plateau_type1_iso (line 237) | def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None):
  function random_bivariate_skew_Gaussian_center (line 259) | def random_bivariate_skew_Gaussian_center(kernel_size,
  function random_bivariate_anisotropic_Gaussian (line 310) | def random_bivariate_anisotropic_Gaussian(kernel_size,
  function random_bivariate_isotropic_Gaussian (line 354) | def random_bivariate_isotropic_Gaussian(kernel_size,
  function random_bivariate_generalized_Gaussian (line 385) | def random_bivariate_generalized_Gaussian(kernel_size,
  function random_bivariate_plateau_type1 (line 435) | def random_bivariate_plateau_type1(kernel_size,
  function random_bivariate_plateau_type1_iso (line 485) | def random_bivariate_plateau_type1_iso(kernel_size,
  function random_mixed_kernels (line 519) | def random_mixed_kernels(kernel_list,
  function show_one_kernel (line 588) | def show_one_kernel():
  function show_plateau_kernel (line 635) | def show_plateau_kernel():

FILE: basicsr/data/paired_image_dataset.py
  class PairedImageDataset (line 11) | class PairedImageDataset(data.Dataset):
    method __init__ (line 42) | def __init__(self, opt):
    method __getitem__ (line 67) | def __getitem__(self, index):
    method __len__ (line 100) | def __len__(self):

FILE: basicsr/data/prefetch_dataloader.py
  class PrefetchGenerator (line 7) | class PrefetchGenerator(threading.Thread):
    method __init__ (line 18) | def __init__(self, generator, num_prefetch_queue):
    method run (line 25) | def run(self):
    method __next__ (line 30) | def __next__(self):
    method __iter__ (line 36) | def __iter__(self):
  class PrefetchDataLoader (line 40) | class PrefetchDataLoader(DataLoader):
    method __init__ (line 55) | def __init__(self, num_prefetch_queue, **kwargs):
    method __iter__ (line 59) | def __iter__(self):
  class CPUPrefetcher (line 63) | class CPUPrefetcher():
    method __init__ (line 70) | def __init__(self, loader):
    method next (line 74) | def next(self):
    method reset (line 80) | def reset(self):
  class CUDAPrefetcher (line 84) | class CUDAPrefetcher():
    method __init__ (line 97) | def __init__(self, loader, opt):
    method preload (line 105) | def preload(self):
    method next (line 117) | def next(self):
    method reset (line 123) | def reset(self):

FILE: basicsr/data/transforms.py
  function mod_crop (line 5) | def mod_crop(img, scale):
  function paired_random_crop (line 25) | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
  function augment (line 80) | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=F...
  function img_rotate (line 147) | def img_rotate(img, angle, center=None, scale=1.0):

FILE: basicsr/losses/__init__.py
  function build_loss (line 14) | def build_loss(opt):

FILE: basicsr/losses/loss_util.py
  function reduce_loss (line 5) | def reduce_loss(loss, reduction):
  function weight_reduce_loss (line 25) | def weight_reduce_loss(loss, weight=None, reduction='mean'):
  function weighted_loss (line 57) | def weighted_loss(loss_func):

FILE: basicsr/losses/losses.py
  function l1_loss (line 16) | def l1_loss(pred, target):
  function mse_loss (line 21) | def mse_loss(pred, target):
  function charbonnier_loss (line 26) | def charbonnier_loss(pred, target, eps=1e-12):
  class L1Loss (line 31) | class L1Loss(nn.Module):
    method __init__ (line 40) | def __init__(self, loss_weight=1.0, reduction='mean'):
    method forward (line 48) | def forward(self, pred, target, weight=None, **kwargs):
  class MSELoss (line 60) | class MSELoss(nn.Module):
    method __init__ (line 69) | def __init__(self, loss_weight=1.0, reduction='mean'):
    method forward (line 77) | def forward(self, pred, target, weight=None, **kwargs):
  class CharbonnierLoss (line 89) | class CharbonnierLoss(nn.Module):
    method __init__ (line 104) | def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
    method forward (line 113) | def forward(self, pred, target, weight=None, **kwargs):
  class WeightedTVLoss (line 125) | class WeightedTVLoss(L1Loss):
    method __init__ (line 132) | def __init__(self, loss_weight=1.0):
    method forward (line 135) | def forward(self, pred, weight=None):
  class PerceptualLoss (line 145) | class PerceptualLoss(nn.Module):
    method __init__ (line 168) | def __init__(self,
    method forward (line 198) | def forward(self, x, gt):
    method _gram_mat (line 240) | def _gram_mat(self, x):
  class LPIPSLoss (line 257) | class LPIPSLoss(nn.Module):
    method __init__ (line 258) | def __init__(self,
    method forward (line 274) | def forward(self, pred, target):
  class GANLoss (line 286) | class GANLoss(nn.Module):
    method __init__ (line 298) | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, l...
    method _wgan_loss (line 318) | def _wgan_loss(self, input, target):
    method _wgan_softplus_loss (line 330) | def _wgan_softplus_loss(self, input, target):
    method get_target_label (line 347) | def get_target_label(self, input, target_is_real):
    method forward (line 364) | def forward(self, input, target_is_real, is_disc=False):
  function r1_penalty (line 390) | def r1_penalty(real_pred, real_img):
  function g_path_regularize (line 407) | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
  function gradient_penalty_loss (line 419) | def gradient_penalty_loss(discriminator, real_data, fake_data, weight=No...

FILE: basicsr/metrics/__init__.py
  function calculate_metric (line 9) | def calculate_metric(data, opt):

FILE: basicsr/metrics/metric_util.py
  function reorder_image (line 6) | def reorder_image(img, input_order='HWC'):
  function to_y_channel (line 32) | def to_y_channel(img):

FILE: basicsr/metrics/psnr_ssim.py
  function calculate_psnr (line 9) | def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_ch...
  function _ssim (line 49) | def _ssim(img1, img2):
  function calculate_ssim (line 84) | def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_ch...

FILE: basicsr/models/__init__.py
  function build_model (line 19) | def build_model(opt):

FILE: basicsr/models/base_model.py
  class BaseModel (line 14) | class BaseModel():
    method __init__ (line 17) | def __init__(self, opt):
    method feed_data (line 24) | def feed_data(self, data):
    method optimize_parameters (line 27) | def optimize_parameters(self):
    method get_current_visuals (line 30) | def get_current_visuals(self):
    method save (line 33) | def save(self, epoch, current_iter):
    method validation (line 37) | def validation(self, dataloader, current_iter, tb_logger, save_img=Fal...
    method model_ema (line 51) | def model_ema(self, decay=0.999):
    method get_current_log (line 60) | def get_current_log(self):
    method model_to_device (line 63) | def model_to_device(self, net):
    method get_optimizer (line 79) | def get_optimizer(self, optim_type, params, lr, **kwargs):
    method setup_schedulers (line 86) | def setup_schedulers(self):
    method get_bare_model (line 99) | def get_bare_model(self, net):
    method print_network (line 108) | def print_network(self, net):
    method _set_lr (line 126) | def _set_lr(self, lr_groups_l):
    method _get_init_lr (line 136) | def _get_init_lr(self):
    method update_learning_rate (line 144) | def update_learning_rate(self, current_iter, warmup_iter=-1):
    method get_current_learning_rate (line 167) | def get_current_learning_rate(self):
    method save_network (line 171) | def save_network(self, net, net_label, current_iter, param_key='params'):
    method _print_different_keys_loading (line 202) | def _print_different_keys_loading(self, crt_net, load_net, strict=True):
    method load_network (line 236) | def load_network(self, net, load_path, strict=True, param_key='params'):
    method save_training_state (line 264) | def save_training_state(self, epoch, current_iter):
    method resume_training (line 282) | def resume_training(self, resume_state):
    method reduce_loss_dict (line 297) | def reduce_loss_dict(self, loss_dict):

FILE: basicsr/models/codeformer_idx_model.py
  class CodeFormerIdxModel (line 15) | class CodeFormerIdxModel(SRModel):
    method feed_data (line 16) | def feed_data(self, data):
    method init_training_settings (line 27) | def init_training_settings(self):
    method setup_optimizers (line 71) | def setup_optimizers(self):
    method optimize_parameters (line 86) | def optimize_parameters(self, current_iter):
    method test (line 127) | def test(self):
    method dist_validation (line 140) | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    method nondist_validation (line 145) | def nondist_validation(self, dataloader, current_iter, tb_logger, save...
    method _log_validation_metric_values (line 197) | def _log_validation_metric_values(self, current_iter, dataset_name, tb...
    method get_current_visuals (line 208) | def get_current_visuals(self):
    method save (line 215) | def save(self, epoch, current_iter):

FILE: basicsr/models/codeformer_joint_model.py
  class CodeFormerJointModel (line 17) | class CodeFormerJointModel(SRModel):
    method feed_data (line 18) | def feed_data(self, data):
    method init_training_settings (line 30) | def init_training_settings(self):
    method calculate_adaptive_weight (line 107) | def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, di...
    method setup_optimizers (line 115) | def setup_optimizers(self):
    method gray_resize_for_identity (line 133) | def gray_resize_for_identity(self, out, size=128):
    method optimize_parameters (line 139) | def optimize_parameters(self, current_iter):
    method test (line 256) | def test(self):
    method dist_validation (line 269) | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    method nondist_validation (line 274) | def nondist_validation(self, dataloader, current_iter, tb_logger, save...
    method _log_validation_metric_values (line 326) | def _log_validation_metric_values(self, current_iter, dataset_name, tb...
    method get_current_visuals (line 337) | def get_current_visuals(self):
    method save (line 344) | def save(self, epoch, current_iter):

FILE: basicsr/models/codeformer_model.py
  class CodeFormerModel (line 16) | class CodeFormerModel(SRModel):
    method feed_data (line 17) | def feed_data(self, data):
    method init_training_settings (line 28) | def init_training_settings(self):
    method calculate_adaptive_weight (line 104) | def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, di...
    method setup_optimizers (line 112) | def setup_optimizers(self):
    method gray_resize_for_identity (line 131) | def gray_resize_for_identity(self, out, size=128):
    method optimize_parameters (line 137) | def optimize_parameters(self, current_iter):
    method test (line 237) | def test(self):
    method dist_validation (line 250) | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    method nondist_validation (line 255) | def nondist_validation(self, dataloader, current_iter, tb_logger, save...
    method _log_validation_metric_values (line 307) | def _log_validation_metric_values(self, current_iter, dataset_name, tb...
    method get_current_visuals (line 318) | def get_current_visuals(self):
    method save (line 325) | def save(self, epoch, current_iter):

FILE: basicsr/models/lr_scheduler.py
  class MultiStepRestartLR (line 6) | class MultiStepRestartLR(_LRScheduler):
    method __init__ (line 19) | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), r...
    method get_lr (line 27) | def get_lr(self):
  function get_position_from_periods (line 36) | def get_position_from_periods(iteration, cumulative_period):
  class CosineAnnealingRestartLR (line 57) | class CosineAnnealingRestartLR(_LRScheduler):
    method __init__ (line 77) | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=...
    method get_lr (line 86) | def get_lr(self):

FILE: basicsr/models/sr_model.py
  class SRModel (line 14) | class SRModel(BaseModel):
    method __init__ (line 17) | def __init__(self, opt):
    method init_training_settings (line 34) | def init_training_settings(self):
    method setup_optimizers (line 72) | def setup_optimizers(self):
    method feed_data (line 86) | def feed_data(self, data):
    method optimize_parameters (line 91) | def optimize_parameters(self, current_iter):
    method test (line 120) | def test(self):
    method dist_validation (line 131) | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    method nondist_validation (line 135) | def nondist_validation(self, dataloader, current_iter, tb_logger, save...
    method _log_validation_metric_values (line 186) | def _log_validation_metric_values(self, current_iter, dataset_name, tb...
    method get_current_visuals (line 196) | def get_current_visuals(self):
    method save (line 204) | def save(self, epoch, current_iter):

FILE: basicsr/models/vqgan_model.py
  class VQGANModel (line 16) | class VQGANModel(SRModel):
    method feed_data (line 17) | def feed_data(self, data):
    method init_training_settings (line 22) | def init_training_settings(self):
    method calculate_adaptive_weight (line 85) | def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, di...
    method adopt_weight (line 93) | def adopt_weight(self, weight, global_step, threshold=0, value=0.):
    method setup_optimizers (line 98) | def setup_optimizers(self):
    method optimize_parameters (line 117) | def optimize_parameters(self, current_iter):
    method test (line 192) | def test(self):
    method dist_validation (line 205) | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
    method nondist_validation (line 210) | def nondist_validation(self, dataloader, current_iter, tb_logger, save...
    method _log_validation_metric_values (line 262) | def _log_validation_metric_values(self, current_iter, dataset_name, tb...
    method get_current_visuals (line 273) | def get_current_visuals(self):
    method save (line 279) | def save(self, epoch, current_iter):

FILE: basicsr/ops/dcn/deform_conv.py
  class DeformConvFunction (line 27) | class DeformConvFunction(Function):
    method forward (line 30) | def forward(ctx,
    method backward (line 69) | def backward(ctx, grad_output):
    method _output_size (line 101) | def _output_size(input, weight, padding, dilation, stride):
  class ModulatedDeformConvFunction (line 115) | class ModulatedDeformConvFunction(Function):
    method forward (line 118) | def forward(ctx,
    method backward (line 152) | def backward(ctx, grad_output):
    method _infer_shape (line 172) | def _infer_shape(ctx, input, weight):
  class DeformConv (line 186) | class DeformConv(nn.Module):
    method __init__ (line 188) | def __init__(self,
    method reset_parameters (line 223) | def reset_parameters(self):
    method forward (line 230) | def forward(self, x, offset):
  class DeformConvPack (line 246) | class DeformConvPack(DeformConv):
    method __init__ (line 264) | def __init__(self, *args, **kwargs):
    method init_offset (line 277) | def init_offset(self):
    method forward (line 281) | def forward(self, x):
  class ModulatedDeformConv (line 287) | class ModulatedDeformConv(nn.Module):
    method __init__ (line 289) | def __init__(self,
    method init_weights (line 320) | def init_weights(self):
    method forward (line 329) | def forward(self, x, offset, mask):
  class ModulatedDeformConvPack (line 334) | class ModulatedDeformConvPack(ModulatedDeformConv):
    method __init__ (line 352) | def __init__(self, *args, **kwargs):
    method init_weights (line 365) | def init_weights(self):
    method forward (line 371) | def forward(self, x):

FILE: basicsr/ops/dcn/src/deform_conv_cuda.cpp
  function shape_check (line 62) | void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOu...
  function deform_conv_forward_cuda (line 152) | int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
  function deform_conv_backward_input_cuda (line 262) | int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
  function deform_conv_backward_parameters_cuda (line 376) | int deform_conv_backward_parameters_cuda(
  function modulated_deform_conv_cuda_forward (line 490) | void modulated_deform_conv_cuda_forward(
  function modulated_deform_conv_cuda_backward (line 571) | void modulated_deform_conv_cuda_backward(

FILE: basicsr/ops/dcn/src/deform_conv_ext.cpp
  function deform_conv_forward (line 52) | int deform_conv_forward(at::Tensor input, at::Tensor weight,
  function deform_conv_backward_input (line 70) | int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
  function deform_conv_backward_parameters (line 89) | int deform_conv_backward_parameters(
  function modulated_deform_conv_forward (line 107) | void modulated_deform_conv_forward(
  function modulated_deform_conv_backward (line 127) | void modulated_deform_conv_backward(
  function PYBIND11_MODULE (line 150) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: basicsr/ops/fused_act/fused_act.py
  class FusedLeakyReLUFunctionBackward (line 24) | class FusedLeakyReLUFunctionBackward(Function):
    method forward (line 27) | def forward(ctx, grad_output, out, negative_slope, scale):
    method backward (line 46) | def backward(ctx, gradgrad_input, gradgrad_bias):
  class FusedLeakyReLUFunction (line 54) | class FusedLeakyReLUFunction(Function):
    method forward (line 57) | def forward(ctx, input, bias, negative_slope, scale):
    method backward (line 67) | def backward(ctx, grad_output):
  class FusedLeakyReLU (line 75) | class FusedLeakyReLU(nn.Module):
    method __init__ (line 77) | def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
    method forward (line 84) | def forward(self, input):
  function fused_leaky_relu (line 88) | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):

FILE: basicsr/ops/fused_act/src/fused_bias_act.cpp
  function fused_bias_act (line 14) | torch::Tensor fused_bias_act(const torch::Tensor& input,
  function PYBIND11_MODULE (line 24) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
  function upfirdn2d (line 13) | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor&...
  function PYBIND11_MODULE (line 22) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: basicsr/ops/upfirdn2d/upfirdn2d.py
  class UpFirDn2dBackward (line 24) | class UpFirDn2dBackward(Function):
    method forward (line 27) | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pa...
    method backward (line 67) | def backward(ctx, gradgrad_input):
  class UpFirDn2d (line 91) | class UpFirDn2d(Function):
    method forward (line 94) | def forward(ctx, input, kernel, up, down, pad):
    method backward (line 129) | def backward(ctx, grad_output):
  function upfirdn2d (line 147) | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
  function upfirdn2d_native (line 156) | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, ...

FILE: basicsr/setup.py
  function readme (line 15) | def readme():
  function get_git_hash (line 21) | def get_git_hash():
  function get_hash (line 46) | def get_hash():
  function write_version_py (line 61) | def write_version_py():
  function get_version (line 78) | def get_version():
  function make_cuda_ext (line 84) | def make_cuda_ext(name, module, sources, sources_cuda=None):
  function get_requirements (line 111) | def get_requirements(filename='requirements.txt'):

FILE: basicsr/train.py
  function parse_options (line 24) | def parse_options(root_path, is_train=True):
  function init_loggers (line 55) | def init_loggers(opt):
  function create_train_val_dataloader (line 71) | def create_train_val_dataloader(opt, logger):
  function train_pipeline (line 110) | def train_pipeline(root_path):

FILE: basicsr/utils/dist_util.py
  function init_dist (line 10) | def init_dist(launcher, backend='nccl', **kwargs):
  function _init_dist_pytorch (line 21) | def _init_dist_pytorch(backend, **kwargs):
  function _init_dist_slurm (line 28) | def _init_dist_slurm(backend, port=None):
  function get_dist_info (line 60) | def get_dist_info():
  function master_only (line 74) | def master_only(func):

FILE: basicsr/utils/download_util.py
  function download_file_from_google_drive (line 11) | def download_file_from_google_drive(file_id, save_path):
  function get_confirm_token (line 41) | def get_confirm_token(response):
  function save_response_content (line 48) | def save_response_content(response, destination, file_size=None, chunk_s...
  function load_file_from_url (line 69) | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):

FILE: basicsr/utils/file_client.py
  class BaseStorageBackend (line 5) | class BaseStorageBackend(metaclass=ABCMeta):
    method get (line 14) | def get(self, filepath):
    method get_text (line 18) | def get_text(self, filepath):
  class MemcachedBackend (line 22) | class MemcachedBackend(BaseStorageBackend):
    method __init__ (line 32) | def __init__(self, server_list_cfg, client_cfg, sys_path=None):
    method get (line 47) | def get(self, filepath):
    method get_text (line 54) | def get_text(self, filepath):
  class HardDiskBackend (line 58) | class HardDiskBackend(BaseStorageBackend):
    method get (line 61) | def get(self, filepath):
    method get_text (line 67) | def get_text(self, filepath):
  class LmdbBackend (line 74) | class LmdbBackend(BaseStorageBackend):
    method __init__ (line 94) | def __init__(self, db_paths, client_keys='default', readonly=True, loc...
    method get (line 114) | def get(self, filepath, client_key):
    method get_text (line 128) | def get_text(self, filepath):
  class FileClient (line 132) | class FileClient(object):
    method __init__ (line 151) | def __init__(self, backend='disk', **kwargs):
    method get (line 158) | def get(self, filepath, client_key='default'):
    method get_text (line 166) | def get_text(self, filepath):

FILE: basicsr/utils/img_util.py
  function img2tensor (line 9) | def img2tensor(imgs, bgr2rgb=True, float32=True):
  function tensor2img (line 38) | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
  function tensor2img_fast (line 97) | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
  function imfrombytes (line 114) | def imfrombytes(content, flag='color', float32=False):
  function imwrite (line 135) | def imwrite(img, file_path, params=None, auto_mkdir=True):
  function crop_border (line 154) | def crop_border(imgs, crop_border):

FILE: basicsr/utils/lmdb_util.py
  function make_lmdb_from_imgs (line 9) | def make_lmdb_from_imgs(data_path,
  function read_img_worker (line 132) | def read_img_worker(path, key, compress_level):
  class LmdbMaker (line 156) | class LmdbMaker():
    method __init__ (line 167) | def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_l...
    method put (line 182) | def put(self, img_byte, key, img_shape):
    method close (line 193) | def close(self):

FILE: basicsr/utils/logger.py
  class MessageLogger (line 10) | class MessageLogger():
    method __init__ (line 22) | def __init__(self, opt, start_iter=1, tb_logger=None):
    method __call__ (line 33) | def __call__(self, log_vars):
  function init_tb_logger (line 78) | def init_tb_logger(log_dir):
  function init_wandb_logger (line 85) | def init_wandb_logger(opt):
  function get_root_logger (line 105) | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_f...
  function get_env_info (line 145) | def get_env_info():

FILE: basicsr/utils/matlab_functions.py
  function cubic (line 6) | def cubic(x):
  function calculate_weights_indices (line 16) | def calculate_weights_indices(in_length, out_length, scale, kernel, kern...
  function imresize (line 86) | def imresize(img, scale, antialiasing=True):
  function rgb2ycbcr (line 169) | def rgb2ycbcr(img, y_only=False):
  function bgr2ycbcr (line 202) | def bgr2ycbcr(img, y_only=False):
  function ycbcr2rgb (line 235) | def ycbcr2rgb(img):
  function ycbcr2bgr (line 264) | def ycbcr2bgr(img):
  function _convert_input_type_range (line 293) | def _convert_input_type_range(img):
  function _convert_output_type_range (line 320) | def _convert_output_type_range(img, dst_type):

FILE: basicsr/utils/misc.py
  function gpu_is_available (line 15) | def gpu_is_available():
  function get_device (line 21) | def get_device(gpu_id=None):
  function set_random_seed (line 35) | def set_random_seed(seed):
  function get_time_str (line 44) | def get_time_str():
  function mkdir_and_rename (line 48) | def mkdir_and_rename(path):
  function make_exp_dirs (line 62) | def make_exp_dirs(opt):
  function scandir (line 74) | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
  function check_resume (line 116) | def check_resume(opt, resume_iter):
  function sizeof_fmt (line 143) | def sizeof_fmt(size, suffix='B'):

FILE: basicsr/utils/options.py
  function ordered_yaml (line 7) | def ordered_yaml():
  function parse (line 32) | def parse(opt_path, root_path, is_train=True):
  function dict2str (line 90) | def dict2str(opt, indent_level=1):

FILE: basicsr/utils/realesrgan_utils.py
  class RealESRGANer (line 14) | class RealESRGANer():
    method __init__ (line 29) | def __init__(self,
    method pre_process (line 71) | def pre_process(self, img):
    method process (line 96) | def process(self):
    method tile_process (line 100) | def tile_process(self):
    method post_process (line 165) | def post_process(self):
    method enhance (line 177) | def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
  class PrefetchReader (line 255) | class PrefetchReader(threading.Thread):
    method __init__ (line 263) | def __init__(self, img_list, num_prefetch_queue):
    method run (line 268) | def run(self):
    method __next__ (line 275) | def __next__(self):
    method __iter__ (line 281) | def __iter__(self):
  class IOConsumer (line 285) | class IOConsumer(threading.Thread):
    method __init__ (line 287) | def __init__(self, opt, que, qid):
    method run (line 293) | def run(self):

FILE: basicsr/utils/registry.py
  class Registry (line 4) | class Registry():
    method __init__ (line 30) | def __init__(self, name):
    method _do_register (line 38) | def _do_register(self, name, obj):
    method register (line 43) | def register(self, obj=None):
    method get (line 62) | def get(self, name):
    method __contains__ (line 68) | def __contains__(self, name):
    method __iter__ (line 71) | def __iter__(self):
    method keys (line 74) | def keys(self):

FILE: basicsr/utils/video_util.py
  function get_video_meta_info (line 17) | def get_video_meta_info(video_path):
  class VideoReader (line 29) | class VideoReader:
    method __init__ (line 30) | def __init__(self, video_path):
    method get_resolution (line 52) | def get_resolution(self):
    method get_fps (line 55) | def get_fps(self):
    method get_audio (line 60) | def get_audio(self):
    method __len__ (line 63) | def __len__(self):
    method get_frame_from_stream (line 66) | def get_frame_from_stream(self):
    method get_frame_from_list (line 73) | def get_frame_from_list(self):
    method get_frame (line 80) | def get_frame(self):
    method close (line 84) | def close(self):
  class VideoWriter (line 89) | class VideoWriter:
    method __init__ (line 90) | def __init__(self, video_save_path, height, width, fps, audio):
    method write_frame (line 113) | def write_frame(self, frame):
    method close (line 123) | def close(self):

FILE: facelib/detection/__init__.py
  function init_detection_model (line 14) | def init_detection_model(model_name, half=False, device='cuda'):
  function init_retinaface_model (line 25) | def init_retinaface_model(model_name, half=False, device='cuda'):
  function init_yolov5face_model (line 49) | def init_yolov5face_model(model_name, device='cuda'):

FILE: facelib/detection/align_trans.py
  class FaceWarpException (line 13) | class FaceWarpException(Exception):
    method __str__ (line 15) | def __str__(self):
  function get_reference_facial_points (line 19) | def get_reference_facial_points(output_size=None, inner_padding_factor=0...
  function get_affine_transform_matrix (line 112) | def get_affine_transform_matrix(src_pts, dst_pts):
  function warp_and_crop_face (line 145) | def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_siz...

FILE: facelib/detection/matlab_cp2tform.py
  class MatlabCp2tormException (line 7) | class MatlabCp2tormException(Exception):
    method __str__ (line 9) | def __str__(self):
  function tformfwd (line 13) | def tformfwd(trans, uv):
  function tforminv (line 37) | def tforminv(trans, uv):
  function findNonreflectiveSimilarity (line 60) | def findNonreflectiveSimilarity(uv, xy, options=None):
  function findSimilarity (line 94) | def findSimilarity(uv, xy, options=None):
  function get_similarity_transform (line 130) | def get_similarity_transform(src_pts, dst_pts, reflective=True):
  function cvt_tform_mat_for_cv2 (line 170) | def cvt_tform_mat_for_cv2(trans):
  function get_similarity_transform_for_cv2 (line 198) | def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):

FILE: facelib/detection/retinaface/retinaface.py
  function generate_config (line 19) | def generate_config(network_name):
  class RetinaFace (line 75) | class RetinaFace(nn.Module):
    method __init__ (line 77) | def __init__(self, network_name='resnet50', half=False, phase='test'):
    method forward (line 122) | def forward(self, inputs):
    method __detect_faces (line 147) | def __detect_faces(self, inputs):
    method transform (line 167) | def transform(self, image, use_origin_size):
    method detect_faces (line 194) | def detect_faces(
    method __align_multi (line 241) | def __align_multi(self, image, boxes, landmarks, limit=None):
    method align_multi (line 259) | def align_multi(self, img, conf_threshold=0.8, limit=None):
    method batched_transform (line 267) | def batched_transform(self, frames, use_origin_size):
    method batched_detect_faces (line 310) | def batched_detect_faces(self, frames, conf_threshold=0.8, nms_thresho...

FILE: facelib/detection/retinaface/retinaface_net.py
  function conv_bn (line 6) | def conv_bn(inp, oup, stride=1, leaky=0):
  function conv_bn_no_relu (line 12) | def conv_bn_no_relu(inp, oup, stride):
  function conv_bn1X1 (line 19) | def conv_bn1X1(inp, oup, stride, leaky=0):
  function conv_dw (line 25) | def conv_dw(inp, oup, stride, leaky=0.1):
  class SSH (line 36) | class SSH(nn.Module):
    method __init__ (line 38) | def __init__(self, in_channel, out_channel):
    method forward (line 52) | def forward(self, input):
  class FPN (line 66) | class FPN(nn.Module):
    method __init__ (line 68) | def __init__(self, in_channels_list, out_channels):
    method forward (line 80) | def forward(self, input):
  class MobileNetV1 (line 100) | class MobileNetV1(nn.Module):
    method __init__ (line 102) | def __init__(self):
    method forward (line 127) | def forward(self, x):
  class ClassHead (line 138) | class ClassHead(nn.Module):
    method __init__ (line 140) | def __init__(self, inchannels=512, num_anchors=3):
    method forward (line 145) | def forward(self, x):
  class BboxHead (line 152) | class BboxHead(nn.Module):
    method __init__ (line 154) | def __init__(self, inchannels=512, num_anchors=3):
    method forward (line 158) | def forward(self, x):
  class LandmarkHead (line 165) | class LandmarkHead(nn.Module):
    method __init__ (line 167) | def __init__(self, inchannels=512, num_anchors=3):
    method forward (line 171) | def forward(self, x):
  function make_class_head (line 178) | def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
  function make_bbox_head (line 185) | def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
  function make_landmark_head (line 192) | def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):

FILE: facelib/detection/retinaface/retinaface_utils.py
  class PriorBox (line 8) | class PriorBox(object):
    method __init__ (line 10) | def __init__(self, cfg, image_size=None, phase='train'):
    method forward (line 19) | def forward(self):
  function py_cpu_nms (line 39) | def py_cpu_nms(dets, thresh):
  function point_form (line 50) | def point_form(boxes):
  function center_size (line 65) | def center_size(boxes):
  function intersect (line 79) | def intersect(box_a, box_b):
  function jaccard (line 98) | def jaccard(box_a, box_b):
  function matrix_iou (line 117) | def matrix_iou(a, b):
  function matrix_iof (line 130) | def matrix_iof(a, b):
  function match (line 142) | def match(threshold, truths, priors, variances, labels, landms, loc_t, c...
  function encode (line 200) | def encode(matched, priors, variances):
  function encode_landm (line 224) | def encode_landm(matched, priors, variances):
  function decode (line 254) | def decode(loc, priors, variances):
  function decode_landm (line 274) | def decode_landm(pre, priors, variances):
  function batched_decode (line 297) | def batched_decode(b_loc, priors, variances):
  function batched_decode_landm (line 320) | def batched_decode_landm(pre, priors, variances):
  function log_sum_exp (line 343) | def log_sum_exp(x):
  function nms (line 357) | def nms(boxes, scores, overlap=0.5, top_k=200):

FILE: facelib/detection/yolov5face/face_detector.py
  function isListempty (line 22) | def isListempty(inList):
  class YoloDetector (line 27) | class YoloDetector:
    method __init__ (line 28) | def __init__(
    method _preprocess (line 48) | def _preprocess(self, imgs):
    method _postprocess (line 69) | def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres):
    method detect_faces (line 104) | def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5):
    method __call__ (line 140) | def __call__(self, *args):

FILE: facelib/detection/yolov5face/models/common.py
  function autopad (line 18) | def autopad(k, p=None):  # kernel, padding
  function channel_shuffle (line 25) | def channel_shuffle(x, groups):
  function DWConv (line 37) | def DWConv(c1, c2, k=1, s=1, act=True):
  class Conv (line 42) | class Conv(nn.Module):
    method __init__ (line 44) | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in,...
    method forward (line 50) | def forward(self, x):
    method fuseforward (line 53) | def fuseforward(self, x):
  class StemBlock (line 57) | class StemBlock(nn.Module):
    method __init__ (line 58) | def __init__(self, c1, c2, k=3, s=2, p=None, g=1, act=True):
    method forward (line 66) | def forward(self, x):
  class Bottleneck (line 74) | class Bottleneck(nn.Module):
    method __init__ (line 76) | def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_ou...
    method forward (line 83) | def forward(self, x):
  class BottleneckCSP (line 87) | class BottleneckCSP(nn.Module):
    method __init__ (line 89) | def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ...
    method forward (line 100) | def forward(self, x):
  class C3 (line 106) | class C3(nn.Module):
    method __init__ (line 108) | def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ...
    method forward (line 116) | def forward(self, x):
  class ShuffleV2Block (line 120) | class ShuffleV2Block(nn.Module):
    method __init__ (line 121) | def __init__(self, inp, oup, stride):
    method depthwise_conv (line 160) | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
    method forward (line 163) | def forward(self, x):
  class SPP (line 173) | class SPP(nn.Module):
    method __init__ (line 175) | def __init__(self, c1, c2, k=(5, 9, 13)):
    method forward (line 182) | def forward(self, x):
  class Focus (line 187) | class Focus(nn.Module):
    method __init__ (line 189) | def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in,...
    method forward (line 193) | def forward(self, x):  # x(b,c,w,h) -> y(b,4c,w/2,h/2)
  class Concat (line 197) | class Concat(nn.Module):
    method __init__ (line 199) | def __init__(self, dimension=1):
    method forward (line 203) | def forward(self, x):
  class NMS (line 207) | class NMS(nn.Module):
    method forward (line 213) | def forward(self, x):
  class AutoShape (line 217) | class AutoShape(nn.Module):
    method __init__ (line 224) | def __init__(self, model):
    method autoshape (line 228) | def autoshape(self):
    method forward (line 232) | def forward(self, imgs, size=640, augment=False, profile=False):
  class Detections (line 275) | class Detections:
    method __init__ (line 277) | def __init__(self, imgs, pred, names=None):
    method __len__ (line 290) | def __len__(self):
    method tolist (line 293) | def tolist(self):

FILE: facelib/detection/yolov5face/models/experimental.py
  class CrossConv (line 10) | class CrossConv(nn.Module):
    method __init__ (line 12) | def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
    method forward (line 20) | def forward(self, x):
  class MixConv2d (line 24) | class MixConv2d(nn.Module):
    method __init__ (line 26) | def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
    method forward (line 44) | def forward(self, x):

FILE: facelib/detection/yolov5face/models/yolo.py
  class Detect (line 29) | class Detect(nn.Module):
    method __init__ (line 33) | def __init__(self, nc=80, anchors=(), ch=()):  # detection layer
    method forward (line 46) | def forward(self, x):
    method _make_grid (line 89) | def _make_grid(nx=20, ny=20):
  class Model (line 95) | class Model(nn.Module):
    method __init__ (line 96) | def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None):  # model, input...
    method forward (line 120) | def forward(self, x):
    method forward_once (line 123) | def forward_once(self, x):
    method _initialize_biases (line 134) | def _initialize_biases(self, cf=None):  # initialize biases into Detec...
    method _print_biases (line 143) | def _print_biases(self):
    method fuse (line 149) | def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
    method nms (line 160) | def nms(self, mode=True):  # add or remove NMS module
    method autoshape (line 174) | def autoshape(self):  # add autoShape module
  function parse_model (line 181) | def parse_model(d, ch):  # model_dict, input_channels(3)

FILE: facelib/detection/yolov5face/utils/autoanchor.py
  function check_anchor_order (line 4) | def check_anchor_order(m):

FILE: facelib/detection/yolov5face/utils/datasets.py
  function letterbox (line 5) | def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=Tru...

FILE: facelib/detection/yolov5face/utils/general.py
  function check_img_size (line 9) | def check_img_size(img_size, s=32):
  function make_divisible (line 17) | def make_divisible(x, divisor):
  function xyxy2xywh (line 22) | def xyxy2xywh(x):
  function xywh2xyxy (line 32) | def xywh2xyxy(x):
  function scale_coords (line 42) | def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
  function clip_coords (line 58) | def clip_coords(boxes, img_shape):
  function box_iou (line 66) | def box_iou(box1, box2):
  function non_max_suppression_face (line 89) | def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45...
  function non_max_suppression (line 168) | def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, cla...
  function scale_coords_landmarks (line 249) | def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):

FILE: facelib/detection/yolov5face/utils/torch_utils.py
  function fuse_conv_and_bn (line 5) | def fuse_conv_and_bn(conv, bn):
  function copy_attr (line 34) | def copy_attr(a, b, include=(), exclude=()):

FILE: facelib/parsing/__init__.py
  function init_parsing_model (line 8) | def init_parsing_model(model_name='bisenet', half=False, device='cuda'):

FILE: facelib/parsing/bisenet.py
  class ConvBNReLU (line 8) | class ConvBNReLU(nn.Module):
    method __init__ (line 10) | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
    method forward (line 15) | def forward(self, x):
  class BiSeNetOutput (line 21) | class BiSeNetOutput(nn.Module):
    method __init__ (line 23) | def __init__(self, in_chan, mid_chan, num_class):
    method forward (line 28) | def forward(self, x):
  class AttentionRefinementModule (line 34) | class AttentionRefinementModule(nn.Module):
    method __init__ (line 36) | def __init__(self, in_chan, out_chan):
    method forward (line 43) | def forward(self, x):
  class ContextPath (line 53) | class ContextPath(nn.Module):
    method __init__ (line 55) | def __init__(self):
    method forward (line 64) | def forward(self, x):
  class FeatureFusionModule (line 87) | class FeatureFusionModule(nn.Module):
    method __init__ (line 89) | def __init__(self, in_chan, out_chan):
    method forward (line 97) | def forward(self, fsp, fcp):
  class BiSeNet (line 110) | class BiSeNet(nn.Module):
    method __init__ (line 112) | def __init__(self, num_class):
    method forward (line 120) | def forward(self, x, return_feat=False):

FILE: facelib/parsing/parsenet.py
  class NormLayer (line 8) | class NormLayer(nn.Module):
    method __init__ (line 16) | def __init__(self, channels, normalize_shape=None, norm_type='bn'):
    method forward (line 35) | def forward(self, x, ref=None):
  class ReluLayer (line 42) | class ReluLayer(nn.Module):
    method __init__ (line 54) | def __init__(self, channels, relu_type='relu'):
    method forward (line 70) | def forward(self, x):
  class ConvLayer (line 74) | class ConvLayer(nn.Module):
    method __init__ (line 76) | def __init__(self,
    method forward (line 103) | def forward(self, x):
  class ResidualBlock (line 113) | class ResidualBlock(nn.Module):
    method __init__ (line 118) | def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', sca...
    method forward (line 132) | def forward(self, x):
  class ParseNet (line 140) | class ParseNet(nn.Module):
    method __init__ (line 142) | def __init__(self,
    method forward (line 188) | def forward(self, x):

FILE: facelib/parsing/resnet.py
  function conv3x3 (line 5) | def conv3x3(in_planes, out_planes, stride=1):
  class BasicBlock (line 10) | class BasicBlock(nn.Module):
    method __init__ (line 12) | def __init__(self, in_chan, out_chan, stride=1):
    method forward (line 26) | def forward(self, x):
  function create_layer_basic (line 41) | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
  class ResNet18 (line 48) | class ResNet18(nn.Module):
    method __init__ (line 50) | def __init__(self):
    method forward (line 60) | def forward(self, x):

FILE: facelib/utils/face_restoration_helper.py
  function get_largest_face (line 18) | def get_largest_face(det_faces, h, w):
  function get_center_face (line 40) | def get_center_face(det_faces, h=0, w=0, center=None):
  class FaceRestoreHelper (line 54) | class FaceRestoreHelper(object):
    method __init__ (line 57) | def __init__(self,
    method set_upscale_factor (line 127) | def set_upscale_factor(self, upscale_factor):
    method read_image (line 130) | def read_image(self, img):
    method init_dlib (line 152) | def init_dlib(self, detection_path, landmark5_path):
    method get_face_landmarks_5_dlib (line 164) | def get_face_landmarks_5_dlib(self,
    method get_face_landmarks_5 (line 196) | def get_face_landmarks_5(self,
    method align_warp_face (line 319) | def align_warp_face(self, save_cropped_path=None, border_mode='constan...
    method get_inverse_affine (line 351) | def get_inverse_affine(self, save_inverse_affine_path=None):
    method add_restored_face (line 364) | def add_restored_face(self, restored_face, input_face=None):
    method paste_faces_to_input_image (line 372) | def paste_faces_to_input_image(self, save_path=None, upsample_img=None...
    method clean_all (line 518) | def clean_all(self):

FILE: facelib/utils/face_utils.py
  function compute_increased_bbox (line 6) | def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
  function get_valid_bboxes (line 23) | def get_valid_bboxes(bboxes, h, w):
  function align_crop_face_landmarks (line 31) | def align_crop_face_landmarks(img,
  function paste_face_back (line 190) | def paste_face_back(img, face, inverse_affine):

FILE: facelib/utils/misc.py
  function download_pretrained_models (line 14) | def download_pretrained_models(file_ids, save_path_root):
  function imwrite (line 38) | def imwrite(img, file_path, params=None, auto_mkdir=True):
  function img2tensor (line 57) | def img2tensor(imgs, bgr2rgb=True, float32=True):
  function load_file_from_url (line 86) | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
  function scandir (line 106) | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
  function is_gray (line 146) | def is_gray(img, threshold=10):
  function rgb2gray (line 162) | def rgb2gray(img, out_channel=3):
  function bgr2gray (line 169) | def bgr2gray(img, out_channel=3):
  function calc_mean_std (line 177) | def calc_mean_std(feat, eps=1e-5):
  function adain_npy (line 191) | def adain_npy(content_feat, style_feat):

FILE: inference_codeformer.py
  function set_realesrgan (line 19) | def set_realesrgan():

FILE: scripts/crop_align_face.py
  function get_landmark (line 38) | def get_landmark(filepath, only_keep_largest=True):
  function align_face (line 78) | def align_face(filepath, out_path):

FILE: scripts/download_pretrained_models.py
  function download_pretrained_models (line 8) | def download_pretrained_models(method, file_urls):

FILE: scripts/download_pretrained_models_from_gdrive.py
  function download_pretrained_models (line 9) | def download_pretrained_models(method, file_ids):

FILE: web-demos/hugging_face/app.py
  function imread (line 62) | def imread(img_path):
  function set_realesrgan (line 68) | def set_realesrgan():
  function inference (line 107) | def inference(image, background_enhance, face_upsample, upscale, codefor...

FILE: web-demos/replicate/predict.py
  class Predictor (line 25) | class Predictor(BasePredictor):
    method setup (line 26) | def setup(self):
    method predict (line 44) | def predict(
  function imread (line 156) | def imread(img_path):
  function set_realesrgan (line 162) | def set_realesrgan():
Condensed preview — 115 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (737K chars).
[
  {
    "path": ".gitignore",
    "chars": 1409,
    "preview": ".vscode\n\n# ignored files\nversion.py\n\n# ignored files with suffix\n*.html\n# *.png\n# *.jpeg\n# *.jpg\n*.pt\n*.gif\n*.pth\n*.dat\n"
  },
  {
    "path": "LICENSE",
    "chars": 1717,
    "preview": "S-Lab License 1.0\n\nCopyright 2022 S-Lab\n\nRedistribution and use for non-commercial purpose in source and \nbinary forms, "
  },
  {
    "path": "README.md",
    "chars": 11926,
    "preview": "<p align=\"center\">\n  <img src=\"assets/CodeFormer_logo.png\" height=110>\n</p>\n\n## Towards Robust Blind Face Restoration wi"
  },
  {
    "path": "basicsr/VERSION",
    "chars": 6,
    "preview": "1.3.2\n"
  },
  {
    "path": "basicsr/__init__.py",
    "chars": 266,
    "preview": "# https://github.com/xinntao/BasicSR\n# flake8: noqa\nfrom .archs import *\nfrom .data import *\nfrom .losses import *\nfrom "
  },
  {
    "path": "basicsr/archs/__init__.py",
    "chars": 886,
    "preview": "import importlib\nfrom copy import deepcopy\nfrom os import path as osp\n\nfrom basicsr.utils import get_root_logger, scandi"
  },
  {
    "path": "basicsr/archs/arcface_arch.py",
    "chars": 8074,
    "preview": "import torch.nn as nn\nfrom basicsr.utils.registry import ARCH_REGISTRY\n\n\ndef conv3x3(inplanes, outplanes, stride=1):\n   "
  },
  {
    "path": "basicsr/archs/arch_util.py",
    "chars": 11431,
    "preview": "import collections.abc\nimport math\nimport torch\nimport torchvision\nimport warnings\nfrom distutils.version import LooseVe"
  },
  {
    "path": "basicsr/archs/codeformer_arch.py",
    "chars": 11117,
    "preview": "import math\nimport numpy as np\nimport torch\nfrom torch import nn, Tensor\nimport torch.nn.functional as F\nfrom typing imp"
  },
  {
    "path": "basicsr/archs/rrdbnet_arch.py",
    "chars": 4620,
    "preview": "import torch\nfrom torch import nn as nn\nfrom torch.nn import functional as F\n\nfrom basicsr.utils.registry import ARCH_RE"
  },
  {
    "path": "basicsr/archs/vgg_arch.py",
    "chars": 6142,
    "preview": "import os\nimport torch\nfrom collections import OrderedDict\nfrom torch import nn as nn\nfrom torchvision.models import vgg"
  },
  {
    "path": "basicsr/archs/vqgan_arch.py",
    "chars": 15428,
    "preview": "'''\nVQGAN code, adapted from the original created by the Unleashing Transformers authors:\nhttps://github.com/samb-t/unle"
  },
  {
    "path": "basicsr/data/__init__.py",
    "chars": 4254,
    "preview": "import importlib\nimport numpy as np\nimport random\nimport torch\nimport torch.utils.data\nfrom copy import deepcopy\nfrom fu"
  },
  {
    "path": "basicsr/data/data_sampler.py",
    "chars": 1639,
    "preview": "import math\nimport torch\nfrom torch.utils.data.sampler import Sampler\n\n\nclass EnlargedSampler(Sampler):\n    \"\"\"Sampler t"
  },
  {
    "path": "basicsr/data/data_util.py",
    "chars": 15045,
    "preview": "import cv2\nimport math\nimport numpy as np\nimport torch\nfrom os import path as osp\nfrom PIL import Image, ImageDraw\nfrom "
  },
  {
    "path": "basicsr/data/ffhq_blind_dataset.py",
    "chars": 13437,
    "preview": "import cv2\nimport math\nimport random\nimport numpy as np\nimport os.path as osp\nfrom scipy.io import loadmat\nfrom PIL impo"
  },
  {
    "path": "basicsr/data/ffhq_blind_joint_dataset.py",
    "chars": 14689,
    "preview": "import cv2\nimport math\nimport random\nimport numpy as np\nimport os.path as osp\nfrom scipy.io import loadmat\nimport torch\n"
  },
  {
    "path": "basicsr/data/gaussian_kernels.py",
    "chars": 25543,
    "preview": "import math\nimport numpy as np\nimport random\nfrom scipy.ndimage.interpolation import shift\nfrom scipy.stats import multi"
  },
  {
    "path": "basicsr/data/paired_image_dataset.py",
    "chars": 4537,
    "preview": "from torch.utils import data as data\nfrom torchvision.transforms.functional import normalize\n\nfrom basicsr.data.data_uti"
  },
  {
    "path": "basicsr/data/prefetch_dataloader.py",
    "chars": 3131,
    "preview": "import queue as Queue\nimport threading\nimport torch\nfrom torch.utils.data import DataLoader\n\n\nclass PrefetchGenerator(th"
  },
  {
    "path": "basicsr/data/transforms.py",
    "chars": 5571,
    "preview": "import cv2\nimport random\n\n\ndef mod_crop(img, scale):\n    \"\"\"Mod crop images, used during testing.\n\n    Args:\n        img"
  },
  {
    "path": "basicsr/losses/__init__.py",
    "chars": 836,
    "preview": "from copy import deepcopy\n\nfrom basicsr.utils import get_root_logger\nfrom basicsr.utils.registry import LOSS_REGISTRY\nfr"
  },
  {
    "path": "basicsr/losses/loss_util.py",
    "chars": 2903,
    "preview": "import functools\nfrom torch.nn import functional as F\n\n\ndef reduce_loss(loss, reduction):\n    \"\"\"Reduce loss as specifie"
  },
  {
    "path": "basicsr/losses/losses.py",
    "chars": 16538,
    "preview": "import math\nimport lpips\nimport torch\nfrom torch import autograd as autograd\nfrom torch import nn as nn\nfrom torch.nn im"
  },
  {
    "path": "basicsr/metrics/__init__.py",
    "chars": 507,
    "preview": "from copy import deepcopy\n\nfrom basicsr.utils.registry import METRIC_REGISTRY\nfrom .psnr_ssim import calculate_psnr, cal"
  },
  {
    "path": "basicsr/metrics/metric_util.py",
    "chars": 1288,
    "preview": "import numpy as np\n\nfrom basicsr.utils.matlab_functions import bgr2ycbcr\n\n\ndef reorder_image(img, input_order='HWC'):\n  "
  },
  {
    "path": "basicsr/metrics/psnr_ssim.py",
    "chars": 4563,
    "preview": "import cv2\nimport numpy as np\n\nfrom basicsr.metrics.metric_util import reorder_image, to_y_channel\nfrom basicsr.utils.re"
  },
  {
    "path": "basicsr/models/__init__.py",
    "chars": 1014,
    "preview": "import importlib\nfrom copy import deepcopy\nfrom os import path as osp\n\nfrom basicsr.utils import get_root_logger, scandi"
  },
  {
    "path": "basicsr/models/base_model.py",
    "chars": 12482,
    "preview": "import logging\nimport os\nimport torch\nfrom collections import OrderedDict\nfrom copy import deepcopy\nfrom torch.nn.parall"
  },
  {
    "path": "basicsr/models/codeformer_idx_model.py",
    "chars": 8883,
    "preview": "import torch\nfrom collections import OrderedDict\nfrom os import path as osp\nfrom tqdm import tqdm\n\nfrom basicsr.archs im"
  },
  {
    "path": "basicsr/models/codeformer_joint_model.py",
    "chars": 14922,
    "preview": "import torch\nfrom collections import OrderedDict\nfrom os import path as osp\nfrom tqdm import tqdm\n\n\nfrom basicsr.archs i"
  },
  {
    "path": "basicsr/models/codeformer_model.py",
    "chars": 14263,
    "preview": "import torch\nfrom collections import OrderedDict\nfrom os import path as osp\nfrom tqdm import tqdm\n\nfrom basicsr.archs im"
  },
  {
    "path": "basicsr/models/lr_scheduler.py",
    "chars": 3956,
    "preview": "import math\nfrom collections import Counter\nfrom torch.optim.lr_scheduler import _LRScheduler\n\n\nclass MultiStepRestartLR"
  },
  {
    "path": "basicsr/models/sr_model.py",
    "chars": 8264,
    "preview": "import torch\nfrom collections import OrderedDict\nfrom os import path as osp\nfrom tqdm import tqdm\n\nfrom basicsr.archs im"
  },
  {
    "path": "basicsr/models/vqgan_model.py",
    "chars": 11645,
    "preview": "import torch\nfrom collections import OrderedDict\nfrom os import path as osp\nfrom tqdm import tqdm\n\nfrom basicsr.archs im"
  },
  {
    "path": "basicsr/ops/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "basicsr/ops/dcn/__init__.py",
    "chars": 306,
    "preview": "from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,\n       "
  },
  {
    "path": "basicsr/ops/dcn/deform_conv.py",
    "chars": 15574,
    "preview": "import math\nimport torch\nfrom torch import nn as nn\nfrom torch.autograd import Function\nfrom torch.autograd.function imp"
  },
  {
    "path": "basicsr/ops/dcn/src/deform_conv_cuda.cpp",
    "chars": 28838,
    "preview": "// modify from\n// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/def"
  },
  {
    "path": "basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu",
    "chars": 42622,
    "preview": "/*!\n ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************\n *\n * COPYRIGHT\n *\n * All contribu"
  },
  {
    "path": "basicsr/ops/dcn/src/deform_conv_ext.cpp",
    "chars": 7492,
    "preview": "// modify from\n// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/def"
  },
  {
    "path": "basicsr/ops/fused_act/__init__.py",
    "chars": 106,
    "preview": "from .fused_act import FusedLeakyReLU, fused_leaky_relu\n\n__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']\n"
  },
  {
    "path": "basicsr/ops/fused_act/fused_act.py",
    "chars": 2717,
    "preview": "# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501\n\nimport torch\nfrom"
  },
  {
    "path": "basicsr/ops/fused_act/src/fused_bias_act.cpp",
    "chars": 1092,
    "preview": "// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp\n#include <torch/extension.h>\n\n"
  },
  {
    "path": "basicsr/ops/fused_act/src/fused_bias_act_kernel.cu",
    "chars": 2874,
    "preview": "// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu\n// Copyright (c) 2019, N"
  },
  {
    "path": "basicsr/ops/upfirdn2d/__init__.py",
    "chars": 58,
    "preview": "from .upfirdn2d import upfirdn2d\n\n__all__ = ['upfirdn2d']\n"
  },
  {
    "path": "basicsr/ops/upfirdn2d/src/upfirdn2d.cpp",
    "chars": 1052,
    "preview": "// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp\n#include <torch/extension.h>\n\n\ntorc"
  },
  {
    "path": "basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu",
    "chars": 11803,
    "preview": "// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu\n// Copyright (c) 2019, NVIDIA"
  },
  {
    "path": "basicsr/ops/upfirdn2d/upfirdn2d.py",
    "chars": 5865,
    "preview": "# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py  # noqa:E501\n\nimport torch\nfro"
  },
  {
    "path": "basicsr/setup.py",
    "chars": 5205,
    "preview": "#!/usr/bin/env python\n\nfrom setuptools import find_packages, setup\n\nimport os\nimport subprocess\nimport sys\nimport time\nf"
  },
  {
    "path": "basicsr/train.py",
    "chars": 9399,
    "preview": "import argparse\nimport datetime\nimport logging\nimport math\nimport copy\nimport random\nimport time\nimport torch\nfrom os im"
  },
  {
    "path": "basicsr/utils/__init__.py",
    "chars": 774,
    "preview": "from .file_client import FileClient\nfrom .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img\nfrom"
  },
  {
    "path": "basicsr/utils/dist_util.py",
    "chars": 2608,
    "preview": "# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py  # noqa: E501\nimport functools\n"
  },
  {
    "path": "basicsr/utils/download_util.py",
    "chars": 3369,
    "preview": "import math\nimport os\nimport requests\nfrom torch.hub import download_url_to_file, get_dir\nfrom tqdm import tqdm\nfrom url"
  },
  {
    "path": "basicsr/utils/file_client.py",
    "chars": 6017,
    "preview": "# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py  # noqa: E501\nfrom abc import "
  },
  {
    "path": "basicsr/utils/img_util.py",
    "chars": 6138,
    "preview": "import cv2\nimport math\nimport numpy as np\nimport os\nimport torch\nfrom torchvision.utils import make_grid\n\n\ndef img2tenso"
  },
  {
    "path": "basicsr/utils/lmdb_util.py",
    "chars": 7105,
    "preview": "import cv2\nimport lmdb\nimport sys\nfrom multiprocessing import Pool\nfrom os import path as osp\nfrom tqdm import tqdm\n\n\nde"
  },
  {
    "path": "basicsr/utils/logger.py",
    "chars": 6318,
    "preview": "import datetime\nimport logging\nimport time\n\nfrom .dist_util import get_dist_info, master_only\n\ninitialized_logger = {}\n\n"
  },
  {
    "path": "basicsr/utils/matlab_functions.py",
    "chars": 13523,
    "preview": "import math\nimport numpy as np\nimport torch\n\n\ndef cubic(x):\n    \"\"\"cubic function used for calculate_weights_indices.\"\"\""
  },
  {
    "path": "basicsr/utils/misc.py",
    "chars": 5193,
    "preview": "import os\nimport re\nimport random\nimport time\nimport torch\nimport numpy as np\nfrom os import path as osp\n\nfrom .dist_uti"
  },
  {
    "path": "basicsr/utils/options.py",
    "chars": 3496,
    "preview": "import yaml\nimport time\nfrom collections import OrderedDict\nfrom os import path as osp\nfrom basicsr.utils.misc import ge"
  },
  {
    "path": "basicsr/utils/realesrgan_utils.py",
    "chars": 12264,
    "preview": "import cv2\nimport math\nimport numpy as np\nimport os\nimport queue\nimport threading\nimport torch\nfrom torch.nn import func"
  },
  {
    "path": "basicsr/utils/registry.py",
    "chars": 2185,
    "preview": "# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py  # noqa: E501\n\n\nclass "
  },
  {
    "path": "basicsr/utils/video_util.py",
    "chars": 4578,
    "preview": "'''\nThe code is modified from the Real-ESRGAN:\nhttps://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_v"
  },
  {
    "path": "docs/history_changelog.md",
    "chars": 1657,
    "preview": "# History of Changelog\n\n- **2023.04.19**: :whale: Training codes and config files are public available now.\n- **2023.04."
  },
  {
    "path": "docs/train.md",
    "chars": 2012,
    "preview": "# :milky_way: Training Procedures\n[English](train.md) **|** [简体中文](train_CN.md)\n## Preparing Dataset\n\n- Download trainin"
  },
  {
    "path": "docs/train_CN.md",
    "chars": 1508,
    "preview": "# :milky_way: 训练文档\n[English](train.md) **|** [简体中文](train_CN.md)\n\n## 准备数据集\n- 下载训练数据集: [FFHQ](https://github.com/NVlabs/f"
  },
  {
    "path": "facelib/detection/__init__.py",
    "chars": 4400,
    "preview": "import os\nimport torch\nfrom torch import nn\nfrom copy import deepcopy\n\nfrom facelib.utils import load_file_from_url\nfrom"
  },
  {
    "path": "facelib/detection/align_trans.py",
    "chars": 7941,
    "preview": "import cv2\nimport numpy as np\n\nfrom .matlab_cp2tform import get_similarity_transform_for_cv2\n\n# reference facial points,"
  },
  {
    "path": "facelib/detection/matlab_cp2tform.py",
    "chars": 8109,
    "preview": "import numpy as np\nfrom numpy.linalg import inv, lstsq\nfrom numpy.linalg import matrix_rank as rank\nfrom numpy.linalg im"
  },
  {
    "path": "facelib/detection/retinaface/retinaface.py",
    "chars": 13397,
    "preview": "import cv2\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom PIL import Image\nf"
  },
  {
    "path": "facelib/detection/retinaface/retinaface_net.py",
    "chars": 6281,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef conv_bn(inp, oup, stride=1, leaky=0):\n    retur"
  },
  {
    "path": "facelib/detection/retinaface/retinaface_utils.py",
    "chars": 16362,
    "preview": "import numpy as np\nimport torch\nimport torchvision\nfrom itertools import product as product\nfrom math import ceil\n\n\nclas"
  },
  {
    "path": "facelib/detection/yolov5face/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "facelib/detection/yolov5face/face_detector.py",
    "chars": 5894,
    "preview": "import cv2\nimport copy\nimport re\nimport torch\nimport numpy as np\n\nfrom pathlib import Path\nfrom facelib.detection.yolov5"
  },
  {
    "path": "facelib/detection/yolov5face/models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "facelib/detection/yolov5face/models/common.py",
    "chars": 11561,
    "preview": "# This file contains modules common to various models\n\nimport math\n\nimport numpy as np\nimport torch\nfrom torch import nn"
  },
  {
    "path": "facelib/detection/yolov5face/models/experimental.py",
    "chars": 1723,
    "preview": "# # This file contains experimental modules\n\nimport numpy as np\nimport torch\nfrom torch import nn\n\nfrom facelib.detectio"
  },
  {
    "path": "facelib/detection/yolov5face/models/yolo.py",
    "chars": 9730,
    "preview": "import math\nfrom copy import deepcopy\nfrom pathlib import Path\n\nimport torch\nimport yaml  # for torch hub\nfrom torch imp"
  },
  {
    "path": "facelib/detection/yolov5face/models/yolov5l.yaml",
    "chars": 1344,
    "preview": "# parameters\nnc: 1  # number of classes\ndepth_multiple: 1.0  # model depth multiple\nwidth_multiple: 1.0  # layer channel"
  },
  {
    "path": "facelib/detection/yolov5face/models/yolov5n.yaml",
    "chars": 1335,
    "preview": "# parameters\nnc: 1  # number of classes\ndepth_multiple: 1.0  # model depth multiple\nwidth_multiple: 1.0  # layer channel"
  },
  {
    "path": "facelib/detection/yolov5face/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "facelib/detection/yolov5face/utils/autoanchor.py",
    "chars": 460,
    "preview": "# Auto-anchor utils\n\n\ndef check_anchor_order(m):\n    # Check anchor order against stride order for YOLOv5 Detect() modul"
  },
  {
    "path": "facelib/detection/yolov5face/utils/datasets.py",
    "chars": 1515,
    "preview": "import cv2\nimport numpy as np\n\n\ndef letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=Fa"
  },
  {
    "path": "facelib/detection/yolov5face/utils/extract_ckpt.py",
    "chars": 237,
    "preview": "import torch\nimport sys\nsys.path.insert(0,'./facelib/detection/yolov5face')\nmodel = torch.load('facelib/detection/yolov5"
  },
  {
    "path": "facelib/detection/yolov5face/utils/general.py",
    "chars": 10348,
    "preview": "import math\nimport time\n\nimport numpy as np\nimport torch\nimport torchvision\n\n\ndef check_img_size(img_size, s=32):\n    # "
  },
  {
    "path": "facelib/detection/yolov5face/utils/torch_utils.py",
    "chars": 1375,
    "preview": "import torch\nfrom torch import nn\n\n\ndef fuse_conv_and_bn(conv, bn):\n    # Fuse convolution and batchnorm layers https://"
  },
  {
    "path": "facelib/parsing/__init__.py",
    "chars": 961,
    "preview": "import torch\n\nfrom facelib.utils import load_file_from_url\nfrom .bisenet import BiSeNet\nfrom .parsenet import ParseNet\n\n"
  },
  {
    "path": "facelib/parsing/bisenet.py",
    "chars": 5190,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .resnet import ResNet18\n\n\nclass ConvBNReLU(nn.M"
  },
  {
    "path": "facelib/parsing/parsenet.py",
    "chars": 6477,
    "preview": "\"\"\"Modified from https://github.com/chaofengc/PSFRGAN\n\"\"\"\nimport numpy as np\nimport torch.nn as nn\nfrom torch.nn import "
  },
  {
    "path": "facelib/parsing/resnet.py",
    "chars": 2357,
    "preview": "import torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef conv3x3(in_planes, out_planes, stride=1):\n    \"\"\"3x3 convolu"
  },
  {
    "path": "facelib/utils/__init__.py",
    "chars": 389,
    "preview": "from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back\nfrom .misc "
  },
  {
    "path": "facelib/utils/face_restoration_helper.py",
    "chars": 25240,
    "preview": "import cv2\nimport numpy as np\nimport os\nimport torch\nfrom torchvision.transforms.functional import normalize\n\nfrom facel"
  },
  {
    "path": "facelib/utils/face_utils.py",
    "chars": 10061,
    "preview": "import cv2\nimport numpy as np\nimport torch\n\n\ndef compute_increased_bbox(bbox, increase_area, preserve_aspect=True):\n    "
  },
  {
    "path": "facelib/utils/misc.py",
    "chars": 7157,
    "preview": "import cv2\nimport os\nimport os.path as osp\nimport numpy as np\nfrom PIL import Image\nimport torch\nfrom torch.hub import d"
  },
  {
    "path": "inference_codeformer.py",
    "chars": 12691,
    "preview": "import os\nimport cv2\nimport argparse\nimport glob\nimport torch\nfrom torchvision.transforms.functional import normalize\nfr"
  },
  {
    "path": "inference_colorization.py",
    "chars": 4125,
    "preview": "import os\nimport cv2\nimport argparse\nimport glob\nimport torch\nfrom torchvision.transforms.functional import normalize\nfr"
  },
  {
    "path": "inference_inpainting.py",
    "chars": 4364,
    "preview": "import os\nimport cv2\nimport argparse\nimport glob\nimport torch\nfrom torchvision.transforms.functional import normalize\nfr"
  },
  {
    "path": "options/CodeFormer_colorization.yml",
    "chars": 3002,
    "preview": "# general settings\nname: CodeFormer_colorization\nmodel_type: CodeFormerIdxModel\nnum_gpu: 8\nmanual_seed: 0\n\n# dataset and"
  },
  {
    "path": "options/CodeFormer_inpainting.yml",
    "chars": 3095,
    "preview": "# general settings\nname: CodeFormer_inpainting\nmodel_type: CodeFormerModel\nnum_gpu: 4\nmanual_seed: 0\n\n# dataset and data"
  },
  {
    "path": "options/CodeFormer_stage2.yml",
    "chars": 3033,
    "preview": "# general settings\nname: CodeFormer_stage2\nmodel_type: CodeFormerIdxModel\nnum_gpu: 8\nmanual_seed: 0\n\n# dataset and data "
  },
  {
    "path": "options/CodeFormer_stage3.yml",
    "chars": 3614,
    "preview": "# general settings\nname: CodeFormer_stage3\nmodel_type: CodeFormerJointModel\nnum_gpu: 8\nmanual_seed: 0\n\n# dataset and dat"
  },
  {
    "path": "options/VQGAN_512_ds32_nearest_stage1.yml",
    "chars": 2564,
    "preview": "# general settings\nname: VQGAN-512-ds32-nearest-stage1\nmodel_type: VQGANModel\nnum_gpu: 8\nmanual_seed: 0\n\n# dataset and d"
  },
  {
    "path": "requirements.txt",
    "chars": 194,
    "preview": "addict\nfuture\nlmdb\nnumpy\nopencv-python\nPillow\npyyaml\nrequests\nscikit-image\nscipy\ntb-nightly\ntorch>=1.7.1\ntorchvision\ntqd"
  },
  {
    "path": "scripts/crop_align_face.py",
    "chars": 7507,
    "preview": "\"\"\"\nbrief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)\nauthor: lzhbrian (https://lzhbrian.m"
  },
  {
    "path": "scripts/download_pretrained_models.py",
    "chars": 2415,
    "preview": "import argparse\nimport os\nfrom os import path as osp\n\nfrom basicsr.utils.download_util import load_file_from_url\n\n\ndef d"
  },
  {
    "path": "scripts/download_pretrained_models_from_gdrive.py",
    "chars": 2307,
    "preview": "import argparse\nimport os\nfrom os import path as osp\n\n# from basicsr.utils.download_util import download_file_from_googl"
  },
  {
    "path": "scripts/generate_latent_gt.py",
    "chars": 2744,
    "preview": "import argparse\nimport glob\nimport numpy as np\nimport os\nimport cv2\nimport torch\nfrom torchvision.transforms.functional "
  },
  {
    "path": "scripts/inference_vqgan.py",
    "chars": 2273,
    "preview": "import argparse\nimport glob\nimport numpy as np\nimport os\nimport cv2\nimport torch\nfrom torchvision.transforms.functional "
  },
  {
    "path": "web-demos/hugging_face/app.py",
    "chars": 11516,
    "preview": "\"\"\"\nThis file is used for deploying hugging face demo:\nhttps://huggingface.co/spaces/sczhou/CodeFormer\n\"\"\"\n\nimport sys\ns"
  },
  {
    "path": "web-demos/replicate/cog.yaml",
    "chars": 745,
    "preview": "\"\"\"\nThis file is used for deploying replicate demo:\nhttps://replicate.com/sczhou/codeformer\n\"\"\"\n\nbuild:\n  gpu: true\n  cu"
  },
  {
    "path": "web-demos/replicate/predict.py",
    "chars": 6678,
    "preview": "\"\"\"\nThis file is used for deploying replicate demo:\nhttps://replicate.com/sczhou/codeformer\nrunning: cog predict -i imag"
  },
  {
    "path": "weights/CodeFormer/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "weights/README.md",
    "chars": 64,
    "preview": "# Weights\n\nPut the downloaded pre-trained models to this folder."
  },
  {
    "path": "weights/facelib/.gitkeep",
    "chars": 0,
    "preview": ""
  }
]

About this extraction

This page contains the full source code of the sczhou/CodeFormer GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 115 files (692.8 KB), approximately 189.9k tokens, and a symbol index with 722 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!