Repository: fudan-generative-vision/DicFace
Branch: main
Commit: df08bb76d051
Files: 100
Total size: 687.8 KB
Directory structure:
gitextract_tgnuarg3/
├── .gitignore
├── README.md
├── basicsr/
│ ├── VERSION
│ ├── __init__.py
│ ├── archs/
│ │ ├── __init__.py
│ │ ├── arcface_arch.py
│ │ ├── arch_util.py
│ │ ├── dir_dist_codeformer_multiscale_arch.py
│ │ ├── rrdbnet_arch.py
│ │ ├── vgg_arch.py
│ │ └── vqgan_arch.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── color_dataset.py
│ │ ├── data_sampler.py
│ │ ├── data_util.py
│ │ ├── degradations.py
│ │ ├── gaussian_kernels.py
│ │ ├── inpainting_dataset.py
│ │ ├── paired_image_dataset.py
│ │ ├── prefetch_dataloader.py
│ │ ├── transforms.py
│ │ └── vfhq_dataset.py
│ ├── losses/
│ │ ├── __init__.py
│ │ ├── loss_util.py
│ │ └── losses.py
│ ├── metrics/
│ │ ├── __init__.py
│ │ ├── metric_util.py
│ │ └── psnr_ssim.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ ├── codeformer_dirichlet_video_model.py
│ │ ├── lr_scheduler.py
│ │ ├── sr_model.py
│ │ └── vqgan_model.py
│ ├── ops/
│ │ ├── __init__.py
│ │ ├── dcn/
│ │ │ ├── __init__.py
│ │ │ ├── deform_conv.py
│ │ │ └── src/
│ │ │ ├── deform_conv_cuda.cpp
│ │ │ ├── deform_conv_cuda_kernel.cu
│ │ │ └── deform_conv_ext.cpp
│ │ ├── fused_act/
│ │ │ ├── __init__.py
│ │ │ ├── fused_act.py
│ │ │ └── src/
│ │ │ ├── fused_bias_act.cpp
│ │ │ └── fused_bias_act_kernel.cu
│ │ └── upfirdn2d/
│ │ ├── __init__.py
│ │ ├── src/
│ │ │ ├── upfirdn2d.cpp
│ │ │ └── upfirdn2d_kernel.cu
│ │ └── upfirdn2d.py
│ ├── setup.py
│ ├── train.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── dist_util.py
│ │ ├── download_util.py
│ │ ├── file_client.py
│ │ ├── img_util.py
│ │ ├── lmdb_util.py
│ │ ├── logger.py
│ │ ├── matlab_functions.py
│ │ ├── misc.py
│ │ ├── options.py
│ │ ├── realesrgan_utils.py
│ │ ├── registry.py
│ │ └── video_util.py
│ └── version.py
├── facelib/
│ ├── detection/
│ │ ├── __init__.py
│ │ ├── align_trans.py
│ │ ├── matlab_cp2tform.py
│ │ ├── retinaface/
│ │ │ ├── retinaface.py
│ │ │ ├── retinaface_net.py
│ │ │ └── retinaface_utils.py
│ │ └── yolov5face/
│ │ ├── __init__.py
│ │ ├── face_detector.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── experimental.py
│ │ │ ├── yolo.py
│ │ │ ├── yolov5l.yaml
│ │ │ └── yolov5n.yaml
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── autoanchor.py
│ │ ├── datasets.py
│ │ ├── extract_ckpt.py
│ │ ├── general.py
│ │ └── torch_utils.py
│ ├── parsing/
│ │ ├── __init__.py
│ │ ├── bisenet.py
│ │ ├── parsenet.py
│ │ └── resnet.py
│ └── utils/
│ ├── __init__.py
│ ├── face_restoration_helper.py
│ ├── face_utils.py
│ └── misc.py
├── options/
│ ├── clip5_bs2_512_align_nofix_multiscale.yaml
│ ├── clip5_bs2_512_align_nofix_multiscale_color.yaml
│ └── clip5_bs2_512_align_nofix_multiscale_inpaint.yaml
├── requirements.txt
├── scripts/
│ ├── inference.py
│ ├── inference_color_and_inpainting.py
│ └── warp_images.py
└── train.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# 忽略操作系统生成的文件
.DS_Store
Thumbs.db
# 忽略编译生成的文件
*.class
*.exe
*.o
*.so
.eggs/
*.egg-info/
# 忽略包管理工具生成的文件
node_modules/
vendor/
# 忽略 Python 缓存目录
__pycache__/
# 忽略日志文件
*.log
# 忽略环境配置文件
.env
# 忽略IDE/编辑器配置文件
.idea/
.vscode/
# test folder
test*/
# ckpts
ckpts/
================================================
FILE: README.md
================================================
DicFace: Dirichlet-Constrained Variational Codebook Learning for Temporally Coherent Video Face Restoration
1Fudan University
2Alibaba Group
3Nanjing University
## 🖼️ Showcase
### Blind Face Restoration
### Face Inpainting
### Face Colorization
### 🐾 Wild Data Examples
## 📰 News
- **`2025/07/25`**: 🎉🎉🎉 Our paper has been accepted to [ICCV 2025](https://iccv.thecvf.com/Conferences/2025)and selected as a highlight.
- **`2025/06/26`**: 🎉🎉🎉 Our paper has been accepted to [ICCV 2025](https://iccv.thecvf.com/Conferences/2025).
- **`2025/06/25`**: Release our test data on huggingface [repo](https://huggingface.co/datasets/fudan-generative-ai/DicFace-test_dataset).
- **`2025/06/23`**: Release our pretrained model on huggingface [repo](https://huggingface.co/fudan-generative-ai/DicFace).
- **`2025/06/17`**: Paper submitted on Arixiv. [paper](https://arxiv.org/abs/2506.13355)
- **`2025/06/16`**: 🎉🎉🎉 Release inference scripts
## 📅️ Roadmap
| Status | Milestone | ETA |
| :----: | :----------------------------------------------------------------------------------------------------- | :--------: |
| ✅ | **[Inference Code release](https://github.com/fudan-generative-vision/DicFace)** | 2025-6-16 |
| ✅ | **[Model Weight release, baidu-link](https://pan.baidu.com/s/1VTNbdtZDvgY0163a1T8ITw?pwd=dicf)** |2025-6-16 |
| ✅ | **[Paper submitted on Arixiv](https://arxiv.org/abs/2506.13355)** | 2025-6-17 |
| ✅ | **[Test data release](https://huggingface.co/datasets/fudan-generative-ai/DicFace-test_dataset)** | 2025-6-25 |
| ✅ | **[Training Code release]()** | 2025-6-26 |
## ⚙️ Installation
- System requirement: PyTorch version >=2.4.1, python == 3.10
- Tested on GPUs: A800, python version == 3.10, PyTorch version == 2.4.1, cuda version == 12.1
Download the codes:
```bash
git clone https://github.com/fudan-generative-vision/DicFace
cd DicFace
```
Create conda environment:
```bash
conda create -n DicFace python=3.10
conda activate DicFace
```
Install PyTorch
```bash
conda install pytorch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 pytorch-cuda=12.1 -c pytorch -c nvidia
```
Install packages with `pip`
```bash
pip install -r requirements.txt
python basicsr/setup.py develop
conda install -c conda-forge dlib
```
### 📥 Download Pretrained Models
The pre-trained weights have been uploaded to Baidu Netdisk. Please download them from the [link](https://pan.baidu.com/s/1VTNbdtZDvgY0163a1T8ITw?pwd=dicf)
Now you can easily get all pretrained models required by inference from our HuggingFace [repo](https://huggingface.co/fudan-generative-ai/DicFace_model).
**File Structure of Pretrained Models**
The downloaded .ckpts directory contains the following pre-trained models:
```
.ckpts
|-- CodeFormer # CodeFormer-related models
| |-- bfr_100k.pth # Blind Face Restoration model
| |-- color_100k.pth # Color Restoration model
| |-- codeformer.pth # codeformer model
| |-- vqgan_discriminator.pth # vqgan_discriminator model
| `-- inpainting_100k.pth # Image Inpainting model
|-- dlib # dlib face-related models
| |-- mmod_human_face_detector.dat # Human face detector
| `-- shape_predictor_5_face_landmarks.dat # 5-point face landmark predictor
|-- facelib # Face processing library models
| |-- detection_Resnet50_Final.pth # ResNet50 face detector
| |-- detection_mobilenet0.25_Final.pth # MobileNet0.25 face detector
| |-- parsing_parsenet.pth # Face parsing model
| |-- yolov5l-face.pth # YOLOv5l face detection model
| `-- yolov5n-face.pth # YOLOv5n face detection model
|-- realesrgan # Real-ESRGAN super-resolution model
| `-- RealESRGAN_x2plus.pth # 2x super-resolution enhancement model
`-- vgg # VGG feature extraction model
`-- vgg.pth # VGG network pre-trained weights
```
### 🎮 Run Inference
#### for blind face restoration
```bash
python scripts/inference.py \
-i /path/to/video \
-o /path/to/output_folder \
--max_length 10 \
--save_video_fps 24 \
--ckpt_path /bfr/bfr_weight.pth \
--bg_upsampler realesrgan \
--save_video
# or your videos has been aligned
python scripts/inference.py \
-i /path/to/video \
-o /path/to/output_folder \
--max_length 10 \
--save_video_fps 24 \
--ckpt_path /bfr/bfr_weight.pth \
--save_video \
--has_aligned
```
#### for colorization & inpainting task
**The current colorization & inpainting tasks only supports input of aligned faces. If a non-aligned face is input, it may lead to unsatisfactory final results.**
``` bash
# for colorization task
python scripts/inference_color_and_inpainting.py \
-i /path/to/video_warped \
-o /path/to/output_folder \
--max_length 10 \
--save_video_fps 24 \
--ckpt_path /colorization/colorization_weight.pth \
--bg_upsampler realesrgan \
--save_video \
--has_aligned
# for inpainting task
python scripts/inference_color_and_inpainting.py \
-i /path/to/video_warped \
-o /path/to/output_folder \
--max_length 10 \
--save_video_fps 24 \
--ckpt_path /inpainting/inpainting_weight.pth \
--bg_upsampler realesrgan \
--save_video \
--has_aligned
```
## Test Data
Our test data can be accessed via the following links:
- Baidu Netdisk: [https://pan.baidu.com/s/1zMp3fnf6LvlRT9CAoL1OUw](https://pan.baidu.com/s/1zMp3fnf6LvlRT9CAoL1OUw) (Password: `drhh`)
- Hugging Face Dataset: [https://huggingface.co/datasets/fudan-generative-ai/DicFace-test_dataset](https://huggingface.co/datasets/fudan-generative-ai/DicFace-test_dataset)
### Directory Structure
The downloaded `test_data_set` directory contains the following folders:
```
./test_data
├── LR_Blind # Blind face restoration test image folders
│ ├── Clip+_HebIzK_LP4+P2+C1+F16589-16715
│ ├── ... # Additional test image folders
│ └── Clip+y5OFsRIRkwc+P0+C0+F9797-9938
│
├── TEST_DATA # Ground-truth (GT) image folders
│ ├── Clip+_HebIzK_LP4+P2+C1+F16589-16715
│ ├── ...
│ └── Clip+y5OFsRIRkwc+P0+C0+F9797-9938
│
├── vfhq_test_color_input # Colorization test image folders
│ ├── Clip+_HebIzK_LP4+P2+C1+F16589-16715
│ ├── ...
│ └── Clip+y5OFsRIRkwc+P0+C0+F9797-9938
│
├── vfhq_test_inpaint_input_512 # Inpainting test image folders (512x512)
│ ├── Clip+_HebIzK_LP4+P2+C1+F16589-16715
│ ├── ...
│ └── Clip+y5OFsRIRkwc+P0+C0+F9797-9938
│
└── vfhq_test_landmarks # Facial landmark files for warping operations
```
### Usage
To process the test data, use the `warp_images.py` script:
```shell
python scripts/warp_images.py \
-i input_test_data_folder \
-o vfhq_test_inpaint_input_512_warped \
-l /path/to/test_data_folder/vfhq_test_landmarks
```
After warping the test data, you can use the inference scripts to generate results for the test dataset.
### Training
#### Training Data
We utilize the VFHQ dataset for both training and testing. The test data is specifically sourced from VFHQ-Test. For more details, please refer to the official project page: [VFHQ](https://liangbinxie.github.io/projects/vfhq/).
### Prerequisites for Training
Before initiating the training process, ensure that you have completed the following steps:
1. **Image Size Requirement**:
- All input images must be resized to 512 x 512 pixels.
2. **Download Necessary Files**:
- Obtain the metadata files and facial landmark information from our Hugging Face repository. [TBD(not ready)]
3. **Configure YAML Files**:
- Edit the configuration file located at `options/xxx.yaml` to specify your training parameters and dataset paths.
### Initiate Training
Once the prerequisites are met, start the training process by executing the following command:
```bash
bash train.sh
```
This script will initiate the training procedure using the settings defined in your YAML configuration file.
## 🤗 Acknowledgements
This project is open sourced under NTU S-Lab License 1.0. Redistribution and use should follow this license. The code framework is mainly modified from [CodeFormer](https://github.com/sczhou/CodeFormer). Please refer to the original repo for more usage and documents.
## 📝 Citation
If you find our work useful for your research, please consider citing the paper:
```
@misc{chen2025dicfacedirichletconstrainedvariationalcodebook,
title={DicFace: Dirichlet-Constrained Variational Codebook Learning for Temporally Coherent Video Face Restoration},
author={Yan Chen and Hanlin Shang and Ce Liu and Yuxuan Chen and Hui Li and Weihao Yuan and Hao Zhu and Zilong Dong and Siyu Zhu},
year={2025},
eprint={2506.13355},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2506.13355},
}
```
================================================
FILE: basicsr/VERSION
================================================
1.3.2
================================================
FILE: basicsr/__init__.py
================================================
# https://github.com/xinntao/BasicSR
# flake8: noqa
from .archs import *
from .data import *
from .losses import *
from .metrics import *
from .models import *
from .ops import *
from .train import *
from .utils import *
from .version import __gitsha__, __version__
================================================
FILE: basicsr/archs/__init__.py
================================================
import importlib
from copy import deepcopy
from os import path as osp
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import ARCH_REGISTRY
__all__ = ['build_network']
# automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
def build_network(opt):
opt = deepcopy(opt)
network_type = opt.pop('type')
net = ARCH_REGISTRY.get(network_type)(**opt)
logger = get_root_logger()
logger.info(f'Network [{net.__class__.__name__}] is created.')
return net
================================================
FILE: basicsr/archs/arcface_arch.py
================================================
import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY
def conv3x3(inplanes, outplanes, stride=1):
"""A simple wrapper for 3x3 convolution with padding.
Args:
inplanes (int): Channel number of inputs.
outplanes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
"""
return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
"""Basic residual block used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
"""
expansion = 1 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class IRBlock(nn.Module):
"""Improved residual block (IR Block) used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
"""
expansion = 1 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
super(IRBlock, self).__init__()
self.bn0 = nn.BatchNorm2d(inplanes)
self.conv1 = conv3x3(inplanes, inplanes)
self.bn1 = nn.BatchNorm2d(inplanes)
self.prelu = nn.PReLU()
self.conv2 = conv3x3(inplanes, planes, stride)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
self.use_se = use_se
if self.use_se:
self.se = SEBlock(planes)
def forward(self, x):
residual = x
out = self.bn0(x)
out = self.conv1(out)
out = self.bn1(out)
out = self.prelu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.use_se:
out = self.se(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.prelu(out)
return out
class Bottleneck(nn.Module):
"""Bottleneck block used in the ResNetArcFace architecture.
Args:
inplanes (int): Channel number of inputs.
planes (int): Channel number of outputs.
stride (int): Stride in convolution. Default: 1.
downsample (nn.Module): The downsample module. Default: None.
"""
expansion = 4 # output channel expansion ratio
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class SEBlock(nn.Module):
"""The squeeze-and-excitation block (SEBlock) used in the IRBlock.
Args:
channel (int): Channel number of inputs.
reduction (int): Channel reduction ration. Default: 16.
"""
def __init__(self, channel, reduction=16):
super(SEBlock, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
nn.Sigmoid())
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y
@ARCH_REGISTRY.register()
class ResNetArcFace(nn.Module):
"""ArcFace with ResNet architectures.
Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
Args:
block (str): Block used in the ArcFace architecture.
layers (tuple(int)): Block numbers in each layer.
use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
"""
def __init__(self, block, layers, use_se=True):
if block == 'IRBlock':
block = IRBlock
self.inplanes = 64
self.use_se = use_se
super(ResNetArcFace, self).__init__()
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.prelu = nn.PReLU()
self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.bn4 = nn.BatchNorm2d(512)
self.dropout = nn.Dropout()
self.fc5 = nn.Linear(512 * 8 * 8, 512)
self.bn5 = nn.BatchNorm1d(512)
# initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)
nn.init.constant_(m.bias, 0)
def _make_layer(self, block, planes, num_blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
self.inplanes = planes
for _ in range(1, num_blocks):
layers.append(block(self.inplanes, planes, use_se=self.use_se))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.prelu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.bn4(x)
x = self.dropout(x)
x = x.view(x.size(0), -1)
x = self.fc5(x)
x = self.bn5(x)
return x
================================================
FILE: basicsr/archs/arch_util.py
================================================
import collections.abc
import math
import torch
import torchvision
import warnings
from distutils.version import LooseVersion
from itertools import repeat
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm
from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
from basicsr.utils import get_root_logger
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
"""Initialize network weights.
Args:
module_list (list[nn.Module] | nn.Module): Modules to be initialized.
scale (float): Scale initialized weights, especially for residual
blocks. Default: 1.
bias_fill (float): The value to fill bias. Default: 0
kwargs (dict): Other arguments for initialization function.
"""
if not isinstance(module_list, list):
module_list = [module_list]
for module in module_list:
for m in module.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, **kwargs)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.fill_(bias_fill)
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, **kwargs)
m.weight.data *= scale
if m.bias is not None:
m.bias.data.fill_(bias_fill)
elif isinstance(m, _BatchNorm):
init.constant_(m.weight, 1)
if m.bias is not None:
m.bias.data.fill_(bias_fill)
def make_layer(basic_block, num_basic_block, **kwarg):
"""Make layers by stacking the same blocks.
Args:
basic_block (nn.module): nn.module class for basic block.
num_basic_block (int): number of blocks.
Returns:
nn.Sequential: Stacked blocks in nn.Sequential.
"""
layers = []
for _ in range(num_basic_block):
layers.append(basic_block(**kwarg))
return nn.Sequential(*layers)
class ResidualBlockNoBN(nn.Module):
"""Residual block without BN.
It has a style of:
---Conv-ReLU-Conv-+-
|________________|
Args:
num_feat (int): Channel number of intermediate features.
Default: 64.
res_scale (float): Residual scale. Default: 1.
pytorch_init (bool): If set to True, use pytorch default init,
otherwise, use default_init_weights. Default: False.
"""
def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
super(ResidualBlockNoBN, self).__init__()
self.res_scale = res_scale
self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
self.relu = nn.ReLU(inplace=True)
if not pytorch_init:
default_init_weights([self.conv1, self.conv2], 0.1)
def forward(self, x):
identity = x
out = self.conv2(self.relu(self.conv1(x)))
return identity + out * self.res_scale
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
"""Warp an image or feature map with optical flow.
Args:
x (Tensor): Tensor with size (n, c, h, w).
flow (Tensor): Tensor with size (n, h, w, 2), normal value.
interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
padding_mode (str): 'zeros' or 'border' or 'reflection'.
Default: 'zeros'.
align_corners (bool): Before pytorch 1.3, the default value is
align_corners=True. After pytorch 1.3, the default value is
align_corners=False. Here, we use the True as default.
Returns:
Tensor: Warped image or feature map.
"""
assert x.size()[-2:] == flow.size()[1:3]
_, _, h, w = x.size()
# create mesh grid
grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
grid.requires_grad = False
vgrid = grid + flow
# scale grid to [-1,1]
vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
# TODO, what if align_corners=False
return output
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
"""Resize a flow according to ratio or shape.
Args:
flow (Tensor): Precomputed flow. shape [N, 2, H, W].
size_type (str): 'ratio' or 'shape'.
sizes (list[int | float]): the ratio for resizing or the final output
shape.
1) The order of ratio should be [ratio_h, ratio_w]. For
downsampling, the ratio should be smaller than 1.0 (i.e., ratio
< 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
ratio > 1.0).
2) The order of output_size should be [out_h, out_w].
interp_mode (str): The mode of interpolation for resizing.
Default: 'bilinear'.
align_corners (bool): Whether align corners. Default: False.
Returns:
Tensor: Resized flow.
"""
_, _, flow_h, flow_w = flow.size()
if size_type == 'ratio':
output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
elif size_type == 'shape':
output_h, output_w = sizes[0], sizes[1]
else:
raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
input_flow = flow.clone()
ratio_h = output_h / flow_h
ratio_w = output_w / flow_w
input_flow[:, 0, :, :] *= ratio_w
input_flow[:, 1, :, :] *= ratio_h
resized_flow = F.interpolate(
input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
return resized_flow
# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
""" Pixel unshuffle.
Args:
x (Tensor): Input feature with shape (b, c, hh, hw).
scale (int): Downsample ratio.
Returns:
Tensor: the pixel unshuffled feature.
"""
b, c, hh, hw = x.size()
out_channel = c * (scale**2)
assert hh % scale == 0 and hw % scale == 0
h = hh // scale
w = hw // scale
x_view = x.view(b, c, h, scale, w, scale)
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
class DCNv2Pack(ModulatedDeformConvPack):
"""Modulated deformable conv for deformable alignment.
Different from the official DCNv2Pack, which generates offsets and masks
from the preceding features, this DCNv2Pack takes another different
features to generate offsets and masks.
Ref:
Delving Deep into Deformable Alignment in Video Super-Resolution.
"""
def forward(self, x, feat):
out = self.conv_offset(feat)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
offset_absmean = torch.mean(torch.abs(offset))
if offset_absmean > 50:
logger = get_root_logger()
logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
self.dilation, mask)
else:
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
self.dilation, self.groups, self.deformable_groups)
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
'The distribution of values may be incorrect.',
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
low = norm_cdf((a - mean) / std)
up = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [low, up], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * low - 1, 2 * up - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution.
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
# From PyTorch
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
================================================
FILE: basicsr/archs/dir_dist_codeformer_multiscale_arch.py
================================================
import math
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Optional, List
from basicsr.archs.vqgan_arch import *
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
import torch.distributions as dist
from einops import rearrange
def calc_mean_std(feat, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, 'The input feature should be 4D tensor.'
b, c = size[:2]
feat_var = feat.view(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().view(b, c, 1, 1)
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat, style_feat):
"""Adaptive instance normalization.
Adjust the reference features to have the similar color and illuminations
as those in the degradate features.
Args:
content_feat (Tensor): The reference feature.
style_feat (Tensor): The degradate features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, x, mask=None):
if mask is None:
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
class TransformerSALayer(nn.Module):
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
# Implementation of Feedforward model - MLP
self.linear1 = nn.Linear(embed_dim, dim_mlp)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_mlp, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(
self,
tgt,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q,
k,
value=tgt2,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
# ffn
tgt2 = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout2(tgt2)
return tgt
class TransformerSALayerTemporal(nn.Module):
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
super().__init__()
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
# Implementation of Feedforward model - MLP
self.linear1 = nn.Linear(embed_dim, dim_mlp)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_mlp, embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(self,
tgt,
frame_length=10,
batch_size=1,
tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt = rearrange(tgt, "d (b t) c -> t (b d) c", t=frame_length)
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(q,
k,
value=tgt2,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
# ffn
tgt2 = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout2(tgt2)
# reshape
tgt = rearrange(tgt, "t (b d) c -> d (b t) c", b=batch_size)
return tgt
class Fuse_sft_block(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.encode_enc = ResBlock(2*in_ch, out_ch)
self.scale = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
self.shift = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
def forward(self, enc_feat, dec_feat, w=1):
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
scale = self.scale(enc_feat)
shift = self.shift(enc_feat)
residual = w * (dec_feat * scale + shift)
out = dec_feat + residual
return out
class ExpModule(nn.Module):
def forward(self, x):
return torch.exp(x)
class MultiScaleFuse(nn.Module):
def __init__(self):
super(MultiScaleFuse, self).__init__()
self.s64_conv = nn.Conv2d(in_channels=256*16, out_channels=256, kernel_size=1)
self.s32_conv = nn.Conv2d(in_channels=256*4, out_channels=256, kernel_size=1)
self.s16_conv = nn.Conv2d(in_channels=256*1, out_channels=256, kernel_size=1)
self.out = nn.Conv2d(in_channels=256*3, out_channels=256, kernel_size=3, stride=1, padding=1)
def forward(self, s64, s32, s16):
feat_64 = rearrange(s64, "bt c (h h1) (w w1) -> bt (c h1 w1) h w", h1=4, w1=4)
feat_64 = self.s64_conv(feat_64)
feat_32 = rearrange(s32, "bt c (h h1) (w w1) -> bt (c h1 w1) h w", h1=2, w1=2)
feat_32 = self.s32_conv(feat_32)
feat_16 = self.s16_conv(s16)
out = self.out(torch.concat([feat_64, feat_32, feat_16], dim=1))
return out
@ARCH_REGISTRY.register()
class TemporalCodeFormerDirDistMultiScale(VQAutoEncoder):
def __init__(self,
dim_embed=512,
n_head=8,
n_layers=9,
codebook_size=1024,
latent_size=256,
connect_list=['32', '64', '128', '256'],
fix_modules=['quantize','generator'],
vqgan_path=None,
frame_length=10,
new_codebook_size=None):
super(TemporalCodeFormerDirDistMultiScale, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest', 2, [16], codebook_size)
if vqgan_path is not None:
self.load_state_dict(
torch.load(vqgan_path, map_location='cpu')['params_ema'])
self.frame_length = frame_length
self.connect_list = connect_list
self.n_layers = n_layers
self.dim_embed = dim_embed
self.dim_mlp = dim_embed * 2
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embed))
self.position_emb_temporal = nn.Parameter(torch.zeros(self.frame_length, self.dim_embed))
self.feat_emb = nn.Linear(256, self.dim_embed)
self.codebook_size = codebook_size
self.new_codebook_size = None
if new_codebook_size is not None:
self.new_codebook_size = new_codebook_size
self.codebook_size += new_codebook_size
self.new_codebook = nn.Parameter(torch.normal(mean=0, std=0.75, size=(new_codebook_size, 256)))
self.new_codebook.requires_grad = True
self.multiscale = MultiScaleFuse()
# transformer in Space
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embed,
nhead=n_head,
dim_mlp=self.dim_mlp,
dropout=0.1)
for _ in range(self.n_layers)])
# transformer in Temporal
self.dir_dist_layers = nn.Sequential(*[TransformerSALayerTemporal(embed_dim=dim_embed,
nhead=n_head,
dim_mlp=self.dim_mlp,
dropout=0.1)
for _ in range(self.n_layers)])
# logits_predict head
self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embed),
nn.Linear(dim_embed, self.codebook_size, bias=False),
)
self.channels = {
'16': 512,
'32': 256,
'64': 256,
'128': 128,
'256': 128,
'512': 64,
}
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
# fuse_convs_dict
self.fuse_convs_dict = nn.ModuleDict()
for f_size in self.connect_list:
in_ch = self.channels[f_size]
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
self.softplus_layer = nn.Softplus()
self.position_emb.requires_grad = False
print("Module: position_emb_spatial Frozen!")
if fix_modules is not None:
print(fix_modules, "frozen!")
for module in fix_modules:
for param_name, param in getattr(self, module).named_parameters():
if "conv3d" in param_name:
param.requires_grad = True
else:
# print(f"Module: {module}, Parameter name: {param_name} Frozen!")
param.requires_grad = False
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
# ################### Encoder #####################
enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks):
x = block(x)
if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone()
lq_feat = self.multiscale(enc_feat_dict['64'], enc_feat_dict['32'], x)
bt, c, h, width = lq_feat.shape
b = bt // self.frame_length
t = self.frame_length
# ################# Spatial & Temporal Transformers ###################
spatial_pos_emb = self.position_emb.unsqueeze(1).repeat(1, bt, 1)
temporal_pos_emb = self.position_emb_temporal.unsqueeze(1).repeat(1, b*h*width, 1)
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
query_emb = feat_emb
for layer_space, layer_temporal in zip(self.ft_layers, self.dir_dist_layers):
query_emb = layer_space(query_emb, query_pos=spatial_pos_emb)
query_emb = layer_temporal(query_emb, query_pos=temporal_pos_emb, frame_length=t, batch_size=b)
alpha = self.idx_pred_layer(query_emb)
alpha = alpha.permute(1, 0, 2)
alpha = self.softplus_layer(alpha) + 1e-2
dirichlet_dist = dist.Dirichlet(alpha)
parameters = dirichlet_dist.rsample()
parameters_reshaped = parameters.reshape(-1, self.codebook_size)
if self.new_codebook_size is not None:
quant_feat = torch.matmul(parameters_reshaped[:, :-self.new_codebook_size], self.quantize.embedding.weight) + \
torch.matmul(parameters_reshaped[:, -self.new_codebook_size:], self.new_codebook)
else:
quant_feat = torch.matmul(parameters_reshaped, self.quantize.embedding.weight)
quant_feat = rearrange(quant_feat, "(b t h w) c -> (b t) c h w", b=b, t=t, h=h, w=width)
if adain:
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
# ################## Generator ####################
x = quant_feat
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks):
x = block(x)
if i in fuse_list:
f_size = str(x.shape[-1])
if w > 0:
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
out = x
return out, lq_feat, alpha + 1e-6
================================================
FILE: basicsr/archs/rrdbnet_arch.py
================================================
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import default_init_weights, make_layer, pixel_unshuffle
class ResidualDenseBlock(nn.Module):
"""Residual Dense Block.
Used in RRDB block in ESRGAN.
Args:
num_feat (int): Channel number of intermediate features.
num_grow_ch (int): Channels for each growth.
"""
def __init__(self, num_feat=64, num_grow_ch=32):
super(ResidualDenseBlock, self).__init__()
self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
# Emperically, we use 0.2 to scale the residual for better performance
return x5 * 0.2 + x
class RRDB(nn.Module):
"""Residual in Residual Dense Block.
Used in RRDB-Net in ESRGAN.
Args:
num_feat (int): Channel number of intermediate features.
num_grow_ch (int): Channels for each growth.
"""
def __init__(self, num_feat, num_grow_ch=32):
super(RRDB, self).__init__()
self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
def forward(self, x):
out = self.rdb1(x)
out = self.rdb2(out)
out = self.rdb3(out)
# Emperically, we use 0.2 to scale the residual for better performance
return out * 0.2 + x
@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
"""Networks consisting of Residual in Residual Dense Block, which is used
in ESRGAN.
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
We extend ESRGAN for scale x2 and scale x1.
Note: This is one option for scale 1, scale 2 in RRDBNet.
We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size
and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
Args:
num_in_ch (int): Channel number of inputs.
num_out_ch (int): Channel number of outputs.
num_feat (int): Channel number of intermediate features.
Default: 64
num_block (int): Block number in the trunk network. Defaults: 23
num_grow_ch (int): Channels for each growth. Default: 32.
"""
def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
super(RRDBNet, self).__init__()
self.scale = scale
if scale == 2:
num_in_ch = num_in_ch * 4
elif scale == 1:
num_in_ch = num_in_ch * 16
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
# upsample
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
if self.scale == 2:
feat = pixel_unshuffle(x, scale=2)
elif self.scale == 1:
feat = pixel_unshuffle(x, scale=4)
else:
feat = x
feat = self.conv_first(feat)
body_feat = self.conv_body(self.body(feat))
feat = feat + body_feat
# upsample
feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.conv_hr(feat)))
return out
================================================
FILE: basicsr/archs/vgg_arch.py
================================================
import os
import torch
from collections import OrderedDict
from torch import nn as nn
from torchvision.models import vgg as vgg
from basicsr.utils.registry import ARCH_REGISTRY
VGG_PRETRAIN_PATH = './ckpts/vgg/vgg16-397923af.pth'
NAMES = {
'vgg11': [
'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
'pool5'
],
'vgg13': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
],
'vgg16': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
'pool5'
],
'vgg19': [
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
]
}
def insert_bn(names):
"""Insert bn layer after each conv.
Args:
names (list): The list of layer names.
Returns:
list: The list of layer names with bn layers.
"""
names_bn = []
for name in names:
names_bn.append(name)
if 'conv' in name:
position = name.replace('conv', '')
names_bn.append('bn' + position)
return names_bn
@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
"""VGG network for feature extraction.
In this implementation, we allow users to choose whether use normalization
in the input feature and the type of vgg network. Note that the pretrained
path must fit the vgg type.
Args:
layer_name_list (list[str]): Forward function returns the corresponding
features according to the layer_name_list.
Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image. Importantly,
the input feature must in the range [0, 1]. Default: True.
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
Default: False.
requires_grad (bool): If true, the parameters of VGG network will be
optimized. Default: False.
remove_pooling (bool): If true, the max pooling operations in VGG net
will be removed. Default: False.
pooling_stride (int): The stride of max pooling operation. Default: 2.
"""
def __init__(self,
layer_name_list,
vgg_type='vgg19',
use_input_norm=True,
range_norm=False,
requires_grad=False,
remove_pooling=False,
pooling_stride=2):
super(VGGFeatureExtractor, self).__init__()
self.layer_name_list = layer_name_list
self.use_input_norm = use_input_norm
self.range_norm = range_norm
self.names = NAMES[vgg_type.replace('_bn', '')]
if 'bn' in vgg_type:
self.names = insert_bn(self.names)
# only borrow layers that will be used to avoid unused params
max_idx = 0
for v in layer_name_list:
idx = self.names.index(v)
if idx > max_idx:
max_idx = idx
if os.path.exists(VGG_PRETRAIN_PATH):
vgg_net = getattr(vgg, vgg_type)(pretrained=False)
state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
vgg_net.load_state_dict(state_dict)
else:
vgg_net = getattr(vgg, vgg_type)(pretrained=True)
features = vgg_net.features[:max_idx + 1]
modified_net = OrderedDict()
for k, v in zip(self.names, features):
if 'pool' in k:
# if remove_pooling is true, pooling operation will be removed
if remove_pooling:
continue
else:
# in some cases, we may want to change the default stride
modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
else:
modified_net[k] = v
self.vgg_net = nn.Sequential(modified_net)
if not requires_grad:
self.vgg_net.eval()
for param in self.parameters():
param.requires_grad = False
else:
self.vgg_net.train()
for param in self.parameters():
param.requires_grad = True
if self.use_input_norm:
# the mean is for image with range [0, 1]
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
# the std is for image with range [0, 1]
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.range_norm:
x = (x + 1) / 2
if self.use_input_norm:
x = (x - self.mean) / self.std
output = {}
for key, layer in self.vgg_net._modules.items():
x = layer(x)
if key in self.layer_name_list:
output[key] = x.clone()
return output
================================================
FILE: basicsr/archs/vqgan_arch.py
================================================
'''
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@torch.jit.script
def swish(x):
return x*torch.sigmoid(x)
# Define VQVAE classes
class VectorQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, beta):
super(VectorQuantizer, self).__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.emb_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
2 * torch.matmul(z_flattened, self.embedding.weight.t())
mean_distance = torch.mean(d)
# find closest encodings
min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
# min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
# [0-1], higher score, higher confidence
# min_encoding_scores = torch.exp(-min_encoding_scores/10)
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
# compute loss for embedding
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
# preserve gradients
z_q = z + (z_q - z).detach()
# perplexity
e_mean = torch.mean(min_encodings, dim=0)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, loss, {
"perplexity": perplexity,
"min_encodings": min_encodings,
"min_encoding_indices": min_encoding_indices,
"mean_distance": mean_distance
}
def get_codebook_feat(self, indices, shape):
# input indices: batch*token_num -> (batch*token_num)*1
# shape: batch, height, width, channel
indices = indices.view(-1,1)
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
min_encodings.scatter_(1, indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
if shape is not None: # reshape back to match original input shape
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
return z_q
class GumbelQuantizer(nn.Module):
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
super().__init__()
self.codebook_size = codebook_size # number of embeddings
self.emb_dim = emb_dim # dimension of embedding
self.straight_through = straight_through
self.temperature = temp_init
self.kl_weight = kl_weight
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
self.embed = nn.Embedding(codebook_size, emb_dim)
def forward(self, z):
hard = self.straight_through if self.training else True
logits = self.proj(z)
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
# + kl divergence to the prior loss
qy = F.softmax(logits, dim=1)
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
min_encoding_indices = soft_one_hot.argmax(dim=1)
return z_q, diff, {
"min_encoding_indices": min_encoding_indices
}
class Downsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels=None):
super(ResBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.norm1 = normalize(in_channels)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = normalize(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x_in):
x = x_in
x = self.norm1(x)
x = swish(x)
x = self.conv1(x)
x = self.norm2(x)
x = swish(x)
x = self.conv2(x)
if self.in_channels != self.out_channels:
x_in = self.conv_out(x_in)
return x + x_in
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.k = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.v = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h*w)
q = q.permute(0, 2, 1)
k = k.reshape(b, c, h*w)
w_ = torch.bmm(q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = F.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h*w)
w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x+h_
class Encoder(nn.Module):
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
super().__init__()
self.nf = nf
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.attn_resolutions = attn_resolutions
curr_res = self.resolution
in_ch_mult = (1,)+tuple(ch_mult)
blocks = []
# initial convultion
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
# residual and downsampling blocks, with attention on smaller res (16x16)
for i in range(self.num_resolutions):
block_in_ch = nf * in_ch_mult[i]
block_out_ch = nf * ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != self.num_resolutions - 1:
blocks.append(Downsample(block_in_ch))
curr_res = curr_res // 2
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
# normalise and convert to latent size
blocks.append(normalize(block_in_ch))
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__()
self.nf = nf
self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim
self.out_channels = 3
block_in_ch = self.nf * self.ch_mult[-1]
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
blocks = []
# initial conv
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
# non-local attention block
blocks.append(ResBlock(block_in_ch, block_in_ch))
blocks.append(AttnBlock(block_in_ch))
blocks.append(ResBlock(block_in_ch, block_in_ch))
for i in reversed(range(self.num_resolutions)):
block_out_ch = self.nf * self.ch_mult[i]
for _ in range(self.num_res_blocks):
blocks.append(ResBlock(block_in_ch, block_out_ch))
block_in_ch = block_out_ch
if curr_res in self.attn_resolutions:
blocks.append(AttnBlock(block_in_ch))
if i != 0:
blocks.append(Upsample(block_in_ch))
curr_res = curr_res * 2
blocks.append(normalize(block_in_ch))
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks)
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__()
logger = get_root_logger()
self.in_channels = 3
self.nf = nf
self.n_blocks = res_blocks
self.codebook_size = codebook_size
self.embed_dim = emb_dim
self.ch_mult = ch_mult
self.resolution = img_size
self.attn_resolutions = attn_resolutions
self.quantizer_type = quantizer
self.encoder = Encoder(
self.in_channels,
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if self.quantizer_type == "nearest":
self.beta = beta #0.25
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
elif self.quantizer_type == "gumbel":
self.gumbel_num_hiddens = emb_dim
self.straight_through = gumbel_straight_through
self.kl_weight = gumbel_kl_weight
self.quantize = GumbelQuantizer(
self.codebook_size,
self.embed_dim,
self.gumbel_num_hiddens,
self.straight_through,
self.kl_weight
)
self.generator = Generator(
self.nf,
self.embed_dim,
self.ch_mult,
self.n_blocks,
self.resolution,
self.attn_resolutions
)
if model_path is not None:
chkpt = torch.load(model_path, map_location='cpu')
if 'params_ema' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
logger.info(f'vqgan is loaded from: {model_path} [params]')
else:
raise ValueError(f'Wrong params!')
def forward(self, x):
x = self.encoder(x)
quant, codebook_loss, quant_stats = self.quantize(x)
x = self.generator(quant)
return x, codebook_loss, quant_stats
# patch based discriminator
@ARCH_REGISTRY.register()
class VQGANDiscriminator(nn.Module):
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
super().__init__()
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
ndf_mult = 1
ndf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True)
]
ndf_mult_prev = ndf_mult
ndf_mult = min(2 ** n_layers, 8)
layers += [
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
nn.BatchNorm2d(ndf * ndf_mult),
nn.LeakyReLU(0.2, True)
]
layers += [
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
self.main = nn.Sequential(*layers)
if model_path is not None:
chkpt = torch.load(model_path, map_location='cpu')
if 'params_d' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
elif 'params' in chkpt:
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
else:
raise ValueError(f'Wrong params!')
def forward(self, x):
return self.main(x)
================================================
FILE: basicsr/data/__init__.py
================================================
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import DATASET_REGISTRY
__all__ = ['build_dataset', 'build_dataloader']
# automatically scan and import dataset modules for registry
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
def build_dataset(dataset_opt):
"""Build dataset from options.
Args:
dataset_opt (dict): Configuration for dataset. It must constain:
name (str): Dataset name.
type (str): Dataset type.
"""
dataset_opt = deepcopy(dataset_opt)
dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
logger = get_root_logger()
logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
return dataset
def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
"""Build dataloader.
Args:
dataset (torch.utils.data.Dataset): Dataset.
dataset_opt (dict): Dataset options. It contains the following keys:
phase (str): 'train' or 'val'.
num_worker_per_gpu (int): Number of workers for each GPU.
batch_size_per_gpu (int): Training batch size for each GPU.
num_gpu (int): Number of GPUs. Used only in the train phase.
Default: 1.
dist (bool): Whether in distributed training. Used only in the train
phase. Default: False.
sampler (torch.utils.data.sampler): Data sampler. Default: None.
seed (int | None): Seed. Default: None
"""
phase = dataset_opt['phase']
rank, _ = get_dist_info()
if phase == 'train':
if dist: # distributed training
batch_size = dataset_opt['batch_size_per_gpu']
num_workers = dataset_opt['num_worker_per_gpu']
else: # non-distributed training
multiplier = 1 if num_gpu == 0 else num_gpu
batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
dataloader_args = dict(
dataset=dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
sampler=sampler,
drop_last=True)
if sampler is None:
dataloader_args['shuffle'] = True
dataloader_args['worker_init_fn'] = partial(
worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
elif phase in ['val', 'test']: # validation
dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
else:
raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
prefetch_mode = dataset_opt.get('prefetch_mode')
if prefetch_mode == 'cpu': # CPUPrefetcher
num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
logger = get_root_logger()
logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
else:
# prefetch_mode=None: Normal dataloader
# prefetch_mode='cuda': dataloader for CUDAPrefetcher
return torch.utils.data.DataLoader(**dataloader_args)
def worker_init_fn(worker_id, num_workers, rank, seed):
# Set the worker seed to num_workers * rank + worker_id + seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
================================================
FILE: basicsr/data/color_dataset.py
================================================
import os
import random
from pathlib import Path
from PIL import Image
import cv2
import ffmpeg
import io
import av
import numpy as np
import torch
from torchvision.transforms.functional import normalize
from basicsr.data.degradations import (random_add_gaussian_noise,
random_mixed_kernels)
from basicsr.data.data_util import paths_from_folder, brush_stroke_mask, brush_stroke_mask_video, random_ff_mask
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, img2tensor, imfrombytes, scandir
from basicsr.utils.registry import DATASET_REGISTRY
from facelib.utils.face_restoration_helper import FaceAligner
from torch.utils import data as data
@DATASET_REGISTRY.register()
class ColorizationDataset(data.Dataset):
def __init__(self, opt):
super(ColorizationDataset, self).__init__()
self.opt = opt
self.gt_root = Path(opt['dataroot_gt'])
self.num_frame = opt['video_length'] # 5
self.scale = opt['scale'] # [1, 4]
self.need_align = opt.get('need_align', False) # False
self.normalize = opt.get('normalize', False) # True
self.keys = []
with open(opt['global_meta_info_file'], 'r') as fin:
for line in fin:
real_clip_path = '/'.join(line.split('/')[:-1])
clip_length = int(line.split('/')[-1])
self.keys.extend([f'{real_clip_path}/{clip_length:08d}/{0:08d}'])
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.is_lmdb = False
if self.io_backend_opt['type'] == 'lmdb':
self.is_lmdb = True
self.io_backend_opt['db_paths'] = [self.gt_root]
self.io_backend_opt['client_keys'] = ['gt']
# temporal augmentation configs
self.interval_list = opt['interval_list'] # [1]
self.random_reverse = opt['random_reverse']
interval_str = ','.join(str(x) for x in opt['interval_list']) # '1'
logger = get_root_logger()
logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
f'random reverse is {self.random_reverse}.')
# degradations
# blur
self.blur_kernel_size = opt['blur_kernel_size'] # 21
self.kernel_list = opt['kernel_list'] # ['iso', 'aniso']
self.kernel_prob = opt['kernel_prob'] # [0.5, 0.5]
self.blur_x_sigma = opt['blur_x_sigma'] # [0.2, 3]
self.blur_y_sigma = opt['blur_y_sigma'] # [0.2, 3]
# noise
self.noise_range = opt['noise_range'] # [0, 25]
# resize
self.resize_prob = opt['resize_prob'] # [0.25, 0.25, 0.5]
# crf
self.crf_range = opt['crf_range'] # [10, 30]
# codec
self.vcodec = opt['vcodec'] # ['libx264']
self.vcodec_prob = opt['vcodec_prob'] # [1]
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
f'x_sigma: [{", ".join(map(str, self.blur_x_sigma))}], '
f'y_sigma: [{", ".join(map(str, self.blur_y_sigma))}], ')
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
logger.info(f'CRF compression: [{", ".join(map(str, self.crf_range))}]')
logger.info(f'Codec: [{", ".join(map(str, self.vcodec))}]')
if self.need_align:
self.dataroot_meta_info = opt['dataroot_meta_info']
self.face_aligner = FaceAligner(
upscale_factor=1,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
use_parse=True)
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
key = self.keys[index]
real_clip_path = '/'.join(key.split('/')[:-2])
clip_length = int(key.split('/')[-2])
frame_idx = int(key.split('/')[-1])
clip_name = real_clip_path.split('/')[-1]
if os.path.exists(os.path.join(self.gt_root, "train", clip_name)):
paths = sorted(list(scandir(os.path.join(self.gt_root, "train", clip_name))))
elif os.path.exists(os.path.join(self.gt_root, "test", clip_name)):
paths = sorted(list(scandir(os.path.join(self.gt_root, "test", clip_name))))
else:
paths = sorted(list(scandir(os.path.join(self.gt_root, clip_name))))
# determine the neighboring frames
interval = random.choice(self.interval_list)
# exceed the length, re-select a new clip
while (clip_length - self.num_frame * interval) < 0:
interval = random.choice(self.interval_list)
# ensure not exceeding the borders
start_frame_idx = frame_idx - self.num_frame // 2 * interval
end_frame_idx = frame_idx + (self.num_frame + 1) // 2 * interval
while (start_frame_idx < 0) or (end_frame_idx > clip_length):
frame_idx = random.randint(self.num_frame // 2 * interval,
clip_length - self.num_frame // 2 * interval)
start_frame_idx = frame_idx - self.num_frame // 2 * interval
end_frame_idx = frame_idx + (self.num_frame + 1) // 2 * interval
neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))
# random reverse
if self.random_reverse and random.random() < 0.5:
neighbor_list.reverse()
assert len(neighbor_list) == self.num_frame, (
f'Wrong length of neighbor list: {len(neighbor_list)}')
# get the neighboring GT frames
img_gts = []
need_align = False
if self.need_align:
clip_info_path = os.path.join(self.dataroot_meta_info, f'{clip_name}.txt')
if os.path.exists(clip_info_path):
need_align = True
clip_info = []
with open(clip_info_path, 'r', encoding='utf-8') as fin:
for line in fin:
line = line.strip()
clip_info.append(line)
for neighbor in neighbor_list:
img_gt_path = os.path.join(self.gt_root, clip_name, paths[neighbor])
if not os.path.exists(img_gt_path):
img_gt_path = os.path.join(self.gt_root, "train", clip_name, paths[neighbor])
if not os.path.exists(img_gt_path):
img_gt_path = os.path.join(self.gt_root, "test", clip_name, paths[neighbor])
img_gt = np.asarray(Image.open(img_gt_path))[:, :, ::-1] / 255.0
img_gts.append(img_gt)
# augmentation - flip, rotate
img_gts = augment(img_gts, self.opt['use_flip'], self.opt['use_rot']) # False, False
# ------------- generate grayscale frames --------------#
img_lqs = img_gts
img_lqs = [cv2.cvtColor((_ * 255).astype('uint8'), cv2.COLOR_BGR2GRAY) for _ in img_lqs]
img_lqs = [np.repeat(_[..., None], repeats=3, axis=2) / 255. for _ in img_lqs]
# -------------- Align ---------------#
if need_align:
align_lqs, align_gts = [], []
for frame_idx, (img_lq, img_gt) in enumerate(zip(img_lqs, img_gts)):
landmarks_str = clip_info[start_frame_idx + frame_idx].split(' ')
landmarks = np.array([float(x) for x in landmarks_str]).reshape(5, 2)
self.face_aligner.clean_all()
# align and warp each face
img_lq, img_gt = self.face_aligner.align_pair_face(img_lq, img_gt, landmarks)
align_lqs.append(img_lq)
align_gts.append(img_gt)
img_lqs, img_gts = align_lqs, align_gts
img_gts = img2tensor(img_gts)
img_lqs = img2tensor(img_lqs)
img_gts = torch.stack(img_gts, dim=0)
img_lqs = torch.stack(img_lqs, dim=0)
if self.normalize:
normalize(img_lqs, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
normalize(img_gts, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
return {'in': img_lqs, 'gt': img_gts, 'key': key}
def __len__(self):
return len(self.keys)
================================================
FILE: basicsr/data/data_sampler.py
================================================
import math
import torch
from torch.utils.data.sampler import Sampler
class EnlargedSampler(Sampler):
"""Sampler that restricts data loading to a subset of the dataset.
Modified from torch.utils.data.distributed.DistributedSampler
Support enlarging the dataset for iteration-based training, for saving
time when restart the dataloader after each epoch
Args:
dataset (torch.utils.data.Dataset): Dataset used for sampling.
num_replicas (int | None): Number of processes participating in
the training. It is usually the world_size.
rank (int | None): Rank of the current process within num_replicas.
ratio (int): Enlarging ratio. Default: 1.
"""
def __init__(self, dataset, num_replicas, rank, ratio=1):
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
self.total_size = self.num_samples * self.num_replicas
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
indices = torch.randperm(self.total_size, generator=g).tolist()
dataset_size = len(self.dataset)
indices = [v % dataset_size for v in indices]
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
================================================
FILE: basicsr/data/data_util.py
================================================
import cv2
import math
import numpy as np
import torch
from os import path as osp
from PIL import Image, ImageDraw
from torch.nn import functional as F
from basicsr.data.transforms import mod_crop
from basicsr.utils import img2tensor, scandir
def read_img_seq(path, require_mod_crop=False, scale=1):
"""Read a sequence of images from a given folder path.
Args:
path (list[str] | str): List of image paths or image folder path.
require_mod_crop (bool): Require mod crop for each image.
Default: False.
scale (int): Scale factor for mod_crop. Default: 1.
Returns:
Tensor: size (t, c, h, w), RGB, [0, 1].
"""
if isinstance(path, list):
img_paths = path
else:
img_paths = sorted(list(scandir(path, full_path=True)))
imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
if require_mod_crop:
imgs = [mod_crop(img, scale) for img in imgs]
imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
imgs = torch.stack(imgs, dim=0)
return imgs
def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
"""Generate an index list for reading `num_frames` frames from a sequence
of images.
Args:
crt_idx (int): Current center index.
max_frame_num (int): Max number of the sequence of images (from 1).
num_frames (int): Reading num_frames frames.
padding (str): Padding mode, one of
'replicate' | 'reflection' | 'reflection_circle' | 'circle'
Examples: current_idx = 0, num_frames = 5
The generated frame indices under different padding mode:
replicate: [0, 0, 0, 1, 2]
reflection: [2, 1, 0, 1, 2]
reflection_circle: [4, 3, 0, 1, 2]
circle: [3, 4, 0, 1, 2]
Returns:
list[int]: A list of indices.
"""
assert num_frames % 2 == 1, 'num_frames should be an odd number.'
assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
max_frame_num = max_frame_num - 1 # start from 0
num_pad = num_frames // 2
indices = []
for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
if i < 0:
if padding == 'replicate':
pad_idx = 0
elif padding == 'reflection':
pad_idx = -i
elif padding == 'reflection_circle':
pad_idx = crt_idx + num_pad - i
else:
pad_idx = num_frames + i
elif i > max_frame_num:
if padding == 'replicate':
pad_idx = max_frame_num
elif padding == 'reflection':
pad_idx = max_frame_num * 2 - i
elif padding == 'reflection_circle':
pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
else:
pad_idx = i - num_frames
else:
pad_idx = i
indices.append(pad_idx)
return indices
def paired_paths_from_lmdb(folders, keys):
"""Generate paired paths from lmdb files.
Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
lq.lmdb
├── data.mdb
├── lock.mdb
├── meta_info.txt
The data.mdb and lock.mdb are standard lmdb files and you can refer to
https://lmdb.readthedocs.io/en/release/ for more details.
The meta_info.txt is a specified txt file to record the meta information
of our datasets. It will be automatically created when preparing
datasets by our provided dataset tools.
Each line in the txt file records
1)image name (with extension),
2)image shape,
3)compression level, separated by a white space.
Example: `baboon.png (120,125,3) 1`
We use the image name without extension as the lmdb key.
Note that we use the same key for the corresponding lq and gt images.
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
Note that this key is different from lmdb keys.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
input_folder, gt_folder = folders
input_key, gt_key = keys
if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
f'formats. But received {input_key}: {input_folder}; '
f'{gt_key}: {gt_folder}')
# ensure that the two meta_info files are the same
with open(osp.join(input_folder, 'meta_info.txt')) as fin:
input_lmdb_keys = [line.split('.')[0] for line in fin]
with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
gt_lmdb_keys = [line.split('.')[0] for line in fin]
if set(input_lmdb_keys) != set(gt_lmdb_keys):
raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
else:
paths = []
for lmdb_key in sorted(input_lmdb_keys):
paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
return paths
def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
"""Generate paired paths from an meta information file.
Each line in the meta information file contains the image names and
image shape (usually for gt), separated by a white space.
Example of an meta information file:
```
0001_s001.png (480,480,3)
0001_s002.png (480,480,3)
```
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
meta_info_file (str): Path to the meta information file.
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Usually the filename_tmpl is
for files in the input folder.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
input_folder, gt_folder = folders
input_key, gt_key = keys
with open(meta_info_file, 'r') as fin:
gt_names = [line.split(' ')[0] for line in fin]
paths = []
for gt_name in gt_names:
basename, ext = osp.splitext(osp.basename(gt_name))
input_name = f'{filename_tmpl.format(basename)}{ext}'
input_path = osp.join(input_folder, input_name)
gt_path = osp.join(gt_folder, gt_name)
paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
return paths
def paired_paths_from_folder(folders, keys, filename_tmpl):
"""Generate paired paths from folders.
Args:
folders (list[str]): A list of folder path. The order of list should
be [input_folder, gt_folder].
keys (list[str]): A list of keys identifying folders. The order should
be in consistent with folders, e.g., ['lq', 'gt'].
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Usually the filename_tmpl is
for files in the input folder.
Returns:
list[str]: Returned path list.
"""
assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
f'But got {len(folders)}')
assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
input_folder, gt_folder = folders
input_key, gt_key = keys
input_paths = list(scandir(input_folder))
gt_paths = list(scandir(gt_folder))
assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
f'{len(input_paths)}, {len(gt_paths)}.')
paths = []
for gt_path in gt_paths:
basename, ext = osp.splitext(osp.basename(gt_path))
input_name = f'{filename_tmpl.format(basename)}{ext}'
input_path = osp.join(input_folder, input_name)
assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
gt_path = osp.join(gt_folder, gt_path)
paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
return paths
def paths_from_folder(folder):
"""Generate paths from folder.
Args:
folder (str): Folder path.
Returns:
list[str]: Returned path list.
"""
paths = list(scandir(folder))
paths = [osp.join(folder, path) for path in paths]
return paths
def paths_from_lmdb(folder):
"""Generate paths from lmdb.
Args:
folder (str): Folder path.
Returns:
list[str]: Returned path list.
"""
if not folder.endswith('.lmdb'):
raise ValueError(f'Folder {folder}folder should in lmdb format.')
with open(osp.join(folder, 'meta_info.txt')) as fin:
paths = [line.split('.')[0] for line in fin]
return paths
def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
"""Generate Gaussian kernel used in `duf_downsample`.
Args:
kernel_size (int): Kernel size. Default: 13.
sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
Returns:
np.array: The Gaussian kernel.
"""
from scipy.ndimage import filters as filters
kernel = np.zeros((kernel_size, kernel_size))
# set element at the middle to one, a dirac delta
kernel[kernel_size // 2, kernel_size // 2] = 1
# gaussian-smooth the dirac, resulting in a gaussian filter
return filters.gaussian_filter(kernel, sigma)
def duf_downsample(x, kernel_size=13, scale=4):
"""Downsamping with Gaussian kernel used in the DUF official code.
Args:
x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
kernel_size (int): Kernel size. Default: 13.
scale (int): Downsampling factor. Supported scale: (2, 3, 4).
Default: 4.
Returns:
Tensor: DUF downsampled frames.
"""
assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
squeeze_flag = False
if x.ndim == 4:
squeeze_flag = True
x = x.unsqueeze(0)
b, t, c, h, w = x.size()
x = x.view(-1, 1, h, w)
pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
x = F.conv2d(x, gaussian_filter, stride=scale)
x = x[:, :, 2:-2, 2:-2]
x = x.view(b, t, c, x.size(2), x.size(3))
if squeeze_flag:
x = x.squeeze(0)
return x
def brush_stroke_mask(img, color=(255,255,255)):
min_num_vertex = 8
max_num_vertex = 28
mean_angle = 2*math.pi / 5
angle_range = 2*math.pi / 12
# training large mask ratio (training setting)
min_width = 30
max_width = 70
# very large mask ratio (test setting and refine after 200k)
# min_width = 80
# max_width = 120
def generate_mask(H, W, img=None):
average_radius = math.sqrt(H*H+W*W) / 8
mask = Image.new('RGB', (W, H), 0)
if img is not None:
mask = img # Image.fromarray(img)
for _ in range(np.random.randint(1, 4)):
num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
angle_min = mean_angle - np.random.uniform(0, angle_range)
angle_max = mean_angle + np.random.uniform(0, angle_range)
angles = []
vertex = []
for i in range(num_vertex):
if i % 2 == 0:
angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
else:
angles.append(np.random.uniform(angle_min, angle_max))
h, w = mask.size
vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
for i in range(num_vertex):
r = np.clip(
np.random.normal(loc=average_radius, scale=average_radius//2),
0, 2*average_radius)
new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
vertex.append((int(new_x), int(new_y)))
draw = ImageDraw.Draw(mask)
width = int(np.random.uniform(min_width, max_width))
draw.line(vertex, fill=color, width=width)
for v in vertex:
draw.ellipse((v[0] - width//2,
v[1] - width//2,
v[0] + width//2,
v[1] + width//2),
fill=color)
return mask
width, height = img.size
mask = generate_mask(height, width, img)
return mask
def brush_stroke_mask_video(imgs, color=(255,255,255)):
min_num_vertex = 8
max_num_vertex = 28
mean_angle = 2 * math.pi / 5
angle_range = 2 * math.pi / 12
# training large mask ratio (training setting)
min_width = 30
max_width = 70
# very large mask ratio (test setting and refine after 200k)
# min_width = 80
# max_width = 120
def generate_mask(H, W, imgs=None):
average_radius = math.sqrt(H*H+W*W) / 8
# mask = Image.new('RGB', (W, H), 0)
# if img is not None:
# mask = img # Image.fromarray(img)
for _ in range(np.random.randint(1, 4)):
num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
angle_min = mean_angle - np.random.uniform(0, angle_range)
angle_max = mean_angle + np.random.uniform(0, angle_range)
angles = []
vertex = []
for i in range(num_vertex):
if i % 2 == 0:
angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
else:
angles.append(np.random.uniform(angle_min, angle_max))
h, w = imgs[0].size
vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
for i in range(num_vertex):
r = np.clip(
np.random.normal(loc=average_radius, scale=average_radius//2),
0, 2*average_radius)
new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
vertex.append((int(new_x), int(new_y)))
width_ = int(np.random.uniform(min_width, max_width))
for img in imgs:
draw = ImageDraw.Draw(img)
draw.line(vertex, fill=color, width=width_)
for v in vertex:
draw.ellipse((v[0] - width_//2,
v[1] - width_//2,
v[0] + width_//2,
v[1] + width_//2),
fill=color)
return imgs
width, height = imgs[0].size
mask = generate_mask(height, width, imgs)
return mask
def random_ff_mask(shape, max_angle = 10, max_len = 100, max_width = 70, times = 10):
"""Generate a random free form mask with configuration.
Args:
config: Config should have configuration including IMG_SHAPES,
VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
Returns:
tuple: (top, left, height, width)
Link:
https://github.com/csqiangwen/DeepFillv2_Pytorch/blob/master/train_dataset.py
"""
height = shape[0]
width = shape[1]
mask = np.zeros((height, width), np.float32)
times = np.random.randint(times-5, times)
for i in range(times):
start_x = np.random.randint(width)
start_y = np.random.randint(height)
for j in range(1 + np.random.randint(5)):
angle = 0.01 + np.random.randint(max_angle)
if i % 2 == 0:
angle = 2 * 3.1415926 - angle
length = 10 + np.random.randint(max_len-20, max_len)
brush_w = 5 + np.random.randint(max_width-30, max_width)
end_x = (start_x + length * np.sin(angle)).astype(np.int32)
end_y = (start_y + length * np.cos(angle)).astype(np.int32)
cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)
start_x, start_y = end_x, end_y
return mask.astype(np.float32)
================================================
FILE: basicsr/data/degradations.py
================================================
import cv2
import math
import numpy as np
import random
import torch
from scipy import special
from scipy.stats import multivariate_normal
from torchvision.transforms.functional import rgb_to_grayscale
# -------------------------------------------------------------------- #
# --------------------------- blur kernels --------------------------- #
# -------------------------------------------------------------------- #
# --------------------------- util functions --------------------------- #
def sigma_matrix2(sig_x, sig_y, theta):
"""Calculate the rotated sigma matrix (two dimensional matrix).
Args:
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
Returns:
ndarray: Rotated sigma matrix.
"""
d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
def mesh_grid(kernel_size):
"""Generate the mesh grid, centering at zero.
Args:
kernel_size (int):
Returns:
xy (ndarray): with the shape (kernel_size, kernel_size, 2)
xx (ndarray): with the shape (kernel_size, kernel_size)
yy (ndarray): with the shape (kernel_size, kernel_size)
"""
ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
xx, yy = np.meshgrid(ax, ax)
xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
1))).reshape(kernel_size, kernel_size, 2)
return xy, xx, yy
def pdf2(sigma_matrix, grid):
"""Calculate PDF of the bivariate Gaussian distribution.
Args:
sigma_matrix (ndarray): with the shape (2, 2)
grid (ndarray): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size.
Returns:
kernel (ndarrray): un-normalized kernel.
"""
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
return kernel
def cdf2(d_matrix, grid):
"""Calculate the CDF of the standard bivariate Gaussian distribution.
Used in skewed Gaussian distribution.
Args:
d_matrix (ndarrasy): skew matrix.
grid (ndarray): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size.
Returns:
cdf (ndarray): skewed cdf.
"""
rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
grid = np.dot(grid, d_matrix)
cdf = rv.cdf(grid)
return cdf
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
"""Generate a bivariate isotropic or anisotropic Gaussian kernel.
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
isotropic (bool):
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
kernel = pdf2(sigma_matrix, grid)
kernel = kernel / np.sum(kernel)
return kernel
def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
"""Generate a bivariate generalized Gaussian kernel.
``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions``
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
beta (float): shape parameter, beta = 1 is the normal distribution.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
kernel = kernel / np.sum(kernel)
return kernel
def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
"""Generate a plateau-like anisotropic kernel.
1 / (1+x^(beta))
Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
beta (float): shape parameter, beta = 1 is the normal distribution.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
if isotropic:
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
else:
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
kernel = kernel / np.sum(kernel)
return kernel
def random_bivariate_Gaussian(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
noise_range=None,
isotropic=True):
"""Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation range (tuple): [-math.pi, math.pi]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
if isotropic is False:
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
else:
sigma_y = sigma_x
rotation = 0
kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
return kernel
def random_bivariate_generalized_Gaussian(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
beta_range,
noise_range=None,
isotropic=True):
"""Randomly generate bivariate generalized Gaussian kernels.
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation range (tuple): [-math.pi, math.pi]
beta_range (tuple): [0.5, 8]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
if isotropic is False:
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
else:
sigma_y = sigma_x
rotation = 0
# assume beta_range[0] < 1 < beta_range[1]
if np.random.uniform() < 0.5:
beta = np.random.uniform(beta_range[0], 1)
else:
beta = np.random.uniform(1, beta_range[1])
kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
return kernel
def random_bivariate_plateau(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
beta_range,
noise_range=None,
isotropic=True):
"""Randomly generate bivariate plateau kernels.
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation range (tuple): [-math.pi/2, math.pi/2]
beta_range (tuple): [1, 4]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
if isotropic is False:
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
else:
sigma_y = sigma_x
rotation = 0
# TODO: this may be not proper
if np.random.uniform() < 0.5:
beta = np.random.uniform(beta_range[0], 1)
else:
beta = np.random.uniform(1, beta_range[1])
kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
return kernel
def random_mixed_kernels(kernel_list,
kernel_prob,
kernel_size=21,
sigma_x_range=(0.6, 5),
sigma_y_range=(0.6, 5),
rotation_range=(-math.pi, math.pi),
betag_range=(0.5, 8),
betap_range=(0.5, 8),
noise_range=None):
"""Randomly generate mixed kernels.
Args:
kernel_list (tuple): a list name of kernel types,
support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
'plateau_aniso']
kernel_prob (tuple): corresponding kernel probability for each
kernel type
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation range (tuple): [-math.pi, math.pi]
beta_range (tuple): [0.5, 8]
noise_range(tuple, optional): multiplicative kernel noise,
[0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
kernel_type = random.choices(kernel_list, kernel_prob)[0]
if kernel_type == 'iso':
kernel = random_bivariate_Gaussian(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
elif kernel_type == 'aniso':
kernel = random_bivariate_Gaussian(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
elif kernel_type == 'generalized_iso':
kernel = random_bivariate_generalized_Gaussian(
kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
betag_range,
noise_range=noise_range,
isotropic=True)
elif kernel_type == 'generalized_aniso':
kernel = random_bivariate_generalized_Gaussian(
kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
betag_range,
noise_range=noise_range,
isotropic=False)
elif kernel_type == 'plateau_iso':
kernel = random_bivariate_plateau(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
elif kernel_type == 'plateau_aniso':
kernel = random_bivariate_plateau(
kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
return kernel
np.seterr(divide='ignore', invalid='ignore')
def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
"""2D sinc filter
Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
Args:
cutoff (float): cutoff frequency in radians (pi is max)
kernel_size (int): horizontal and vertical size, must be odd.
pad_to (int): pad kernel size to desired size, must be odd or zero.
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
kernel = np.fromfunction(
lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
(x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
(x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
kernel = kernel / np.sum(kernel)
if pad_to > kernel_size:
pad_size = (pad_to - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
return kernel
# ------------------------------------------------------------- #
# --------------------------- noise --------------------------- #
# ------------------------------------------------------------- #
# ----------------------- Gaussian Noise ----------------------- #
def generate_gaussian_noise(img, sigma=10, gray_noise=False):
"""Generate Gaussian noise.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
sigma (float): Noise scale (measured in range 255). Default: 10.
Returns:
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
float32.
"""
if gray_noise:
noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
else:
noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
return noise
def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
"""Add Gaussian noise.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
sigma (float): Noise scale (measured in range 255). Default: 10.
Returns:
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
float32.
"""
noise = generate_gaussian_noise(img, sigma, gray_noise)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
"""Add Gaussian noise (PyTorch version).
Args:
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
scale (float | Tensor): Noise scale. Default: 1.0.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
b, _, h, w = img.size()
if not isinstance(sigma, (float, int)):
sigma = sigma.view(img.size(0), 1, 1, 1)
if isinstance(gray_noise, (float, int)):
cal_gray_noise = gray_noise > 0
else:
gray_noise = gray_noise.view(b, 1, 1, 1)
cal_gray_noise = torch.sum(gray_noise) > 0
if cal_gray_noise:
noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
noise_gray = noise_gray.view(b, 1, h, w)
# always calculate color noise
noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
if cal_gray_noise:
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
return noise
def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
"""Add Gaussian noise (PyTorch version).
Args:
img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
scale (float | Tensor): Noise scale. Default: 1.0.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
# ----------------------- Random Gaussian Noise ----------------------- #
def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
sigma = np.random.uniform(sigma_range[0], sigma_range[1])
if np.random.uniform() < gray_prob:
gray_noise = True
else:
gray_noise = False
return generate_gaussian_noise(img, sigma, gray_noise)
def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
sigma = torch.rand(
img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
gray_noise = (gray_noise < gray_prob).float()
return generate_gaussian_noise_pt(img, sigma, gray_noise)
def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
# ----------------------- Poisson (Shot) Noise ----------------------- #
def generate_poisson_noise(img, scale=1.0, gray_noise=False):
"""Generate poisson noise.
Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
scale (float): Noise scale. Default: 1.0.
gray_noise (bool): Whether generate gray noise. Default: False.
Returns:
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
float32.
"""
if gray_noise:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# round and clip image for counting vals correctly
img = np.clip((img * 255.0).round(), 0, 255) / 255.
vals = len(np.unique(img))
vals = 2**np.ceil(np.log2(vals))
out = np.float32(np.random.poisson(img * vals) / float(vals))
noise = out - img
if gray_noise:
noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
return noise * scale
def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
"""Add poisson noise.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
scale (float): Noise scale. Default: 1.0.
gray_noise (bool): Whether generate gray noise. Default: False.
Returns:
(Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
float32.
"""
noise = generate_poisson_noise(img, scale, gray_noise)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
"""Generate a batch of poisson noise (PyTorch version)
Args:
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
Default: 1.0.
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
0 for False, 1 for True. Default: 0.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
b, _, h, w = img.size()
if isinstance(gray_noise, (float, int)):
cal_gray_noise = gray_noise > 0
else:
gray_noise = gray_noise.view(b, 1, 1, 1)
cal_gray_noise = torch.sum(gray_noise) > 0
if cal_gray_noise:
img_gray = rgb_to_grayscale(img, num_output_channels=1)
# round and clip image for counting vals correctly
img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
# use for-loop to get the unique values for each sample
vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
out = torch.poisson(img_gray * vals) / vals
noise_gray = out - img_gray
noise_gray = noise_gray.expand(b, 3, h, w)
# always calculate color noise
# round and clip image for counting vals correctly
img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
# use for-loop to get the unique values for each sample
vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
out = torch.poisson(img * vals) / vals
noise = out - img
if cal_gray_noise:
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
if not isinstance(scale, (float, int)):
scale = scale.view(b, 1, 1, 1)
return noise * scale
def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
"""Add poisson noise to a batch of images (PyTorch version).
Args:
img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
Default: 1.0.
gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
0 for False, 1 for True. Default: 0.
Returns:
(Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
float32.
"""
noise = generate_poisson_noise_pt(img, scale, gray_noise)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
# ----------------------- Random Poisson (Shot) Noise ----------------------- #
def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
scale = np.random.uniform(scale_range[0], scale_range[1])
if np.random.uniform() < gray_prob:
gray_noise = True
else:
gray_noise = False
return generate_poisson_noise(img, scale, gray_noise)
def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_poisson_noise(img, scale_range, gray_prob)
out = img + noise
if clip and rounds:
out = np.clip((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
scale = torch.rand(
img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
gray_noise = (gray_noise < gray_prob).float()
return generate_poisson_noise_pt(img, scale, gray_noise)
def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
out = img + noise
if clip and rounds:
out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
out = (out * 255.0).round() / 255.
return out
# ------------------------------------------------------------------------ #
# --------------------------- JPEG compression --------------------------- #
# ------------------------------------------------------------------------ #
def add_jpg_compression(img, quality=90):
"""Add JPG compression artifacts.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
quality (float): JPG compression quality. 0 for lowest quality, 100 for
best quality. Default: 90.
Returns:
(Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
float32.
"""
img = np.clip(img, 0, 1)
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
_, encimg = cv2.imencode('.jpg', img * 255., encode_param)
img = np.float32(cv2.imdecode(encimg, 1)) / 255.
return img
def random_add_jpg_compression(img, quality_range=(90, 100)):
"""Randomly add JPG compression artifacts.
Args:
img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
quality_range (tuple[float] | list[float]): JPG compression quality
range. 0 for lowest quality, 100 for best quality.
Default: (90, 100).
Returns:
(Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
float32.
"""
quality = np.random.uniform(quality_range[0], quality_range[1])
return add_jpg_compression(img, quality)
================================================
FILE: basicsr/data/gaussian_kernels.py
================================================
import math
import numpy as np
import random
from scipy.ndimage.interpolation import shift
from scipy.stats import multivariate_normal
def sigma_matrix2(sig_x, sig_y, theta):
"""Calculate the rotated sigma matrix (two dimensional matrix).
Args:
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
Returns:
ndarray: Rotated sigma matrix.
"""
D = np.array([[sig_x**2, 0], [0, sig_y**2]])
U = np.array([[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]])
return np.dot(U, np.dot(D, U.T))
def mesh_grid(kernel_size):
"""Generate the mesh grid, centering at zero.
Args:
kernel_size (int):
Returns:
xy (ndarray): with the shape (kernel_size, kernel_size, 2)
xx (ndarray): with the shape (kernel_size, kernel_size)
yy (ndarray): with the shape (kernel_size, kernel_size)
"""
ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
xx, yy = np.meshgrid(ax, ax)
xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)),
yy.reshape(kernel_size * kernel_size,
1))).reshape(kernel_size, kernel_size, 2)
return xy, xx, yy
def pdf2(sigma_matrix, grid):
"""Calculate PDF of the bivariate Gaussian distribution.
Args:
sigma_matrix (ndarray): with the shape (2, 2)
grid (ndarray): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size.
Returns:
kernel (ndarrray): un-normalized kernel.
"""
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
return kernel
def cdf2(D, grid):
"""Calculate the CDF of the standard bivariate Gaussian distribution.
Used in skewed Gaussian distribution.
Args:
D (ndarrasy): skew matrix.
grid (ndarray): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size.
Returns:
cdf (ndarray): skewed cdf.
"""
rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
grid = np.dot(grid, D)
cdf = rv.cdf(grid)
return cdf
def bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid=None):
"""Generate a bivariate skew Gaussian kernel.
Described in `A multivariate skew normal distribution`_ by Shi et. al (2004).
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
D (ndarrasy): skew matrix.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): normalized kernel.
.. _A multivariate skew normal distribution:
https://www.sciencedirect.com/science/article/pii/S0047259X03001313
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
pdf = pdf2(sigma_matrix, grid)
cdf = cdf2(D, grid)
kernel = pdf * cdf
kernel = kernel / np.sum(kernel)
return kernel
def mass_center_shift(kernel_size, kernel):
"""Calculate the shift of the mass center of a kenrel.
Args:
kernel_size (int):
kernel (ndarray): normalized kernel.
Returns:
delta_h (float):
delta_w (float):
"""
ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
col_sum, row_sum = np.sum(kernel, axis=0), np.sum(kernel, axis=1)
delta_h = np.dot(row_sum, ax)
delta_w = np.dot(col_sum, ax)
return delta_h, delta_w
def bivariate_skew_Gaussian_center(kernel_size,
sig_x,
sig_y,
theta,
D,
grid=None):
"""Generate a bivariate skew Gaussian kernel at center. Shift with nearest padding.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
D (ndarrasy): skew matrix.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): centered and normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
kernel = bivariate_skew_Gaussian(kernel_size, sig_x, sig_y, theta, D, grid)
delta_h, delta_w = mass_center_shift(kernel_size, kernel)
kernel = shift(kernel, [-delta_h, -delta_w], mode='nearest')
kernel = kernel / np.sum(kernel)
return kernel
def bivariate_anisotropic_Gaussian(kernel_size,
sig_x,
sig_y,
theta,
grid=None):
"""Generate a bivariate anisotropic Gaussian kernel.
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
kernel = pdf2(sigma_matrix, grid)
kernel = kernel / np.sum(kernel)
return kernel
def bivariate_isotropic_Gaussian(kernel_size, sig, grid=None):
"""Generate a bivariate isotropic Gaussian kernel.
Args:
kernel_size (int):
sig (float):
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
kernel = pdf2(sigma_matrix, grid)
kernel = kernel / np.sum(kernel)
return kernel
def bivariate_generalized_Gaussian(kernel_size,
sig_x,
sig_y,
theta,
beta,
grid=None):
"""Generate a bivariate generalized Gaussian kernel.
Described in `Parameter Estimation For Multivariate Generalized Gaussian Distributions`_
by Pascal et. al (2013).
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
beta (float): shape parameter, beta = 1 is the normal distribution.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): normalized kernel.
.. _Parameter Estimation For Multivariate Generalized Gaussian Distributions:
https://arxiv.org/abs/1302.6498
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.exp(
-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
kernel = kernel / np.sum(kernel)
return kernel
def bivariate_plateau_type1(kernel_size, sig_x, sig_y, theta, beta, grid=None):
"""Generate a plateau-like anisotropic kernel.
1 / (1+x^(beta))
Args:
kernel_size (int):
sig_x (float):
sig_y (float):
theta (float): Radian measurement.
beta (float): shape parameter, beta = 1 is the normal distribution.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.reciprocal(
np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
kernel = kernel / np.sum(kernel)
return kernel
def bivariate_plateau_type1_iso(kernel_size, sig, beta, grid=None):
"""Generate a plateau-like isotropic kernel.
1 / (1+x^(beta))
Args:
kernel_size (int):
sig (float):
beta (float): shape parameter, beta = 1 is the normal distribution.
grid (ndarray, optional): generated by :func:`mesh_grid`,
with the shape (K, K, 2), K is the kernel size. Default: None
Returns:
kernel (ndarray): normalized kernel.
"""
if grid is None:
grid, _, _ = mesh_grid(kernel_size)
sigma_matrix = np.array([[sig**2, 0], [0, sig**2]])
inverse_sigma = np.linalg.inv(sigma_matrix)
kernel = np.reciprocal(
np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
kernel = kernel / np.sum(kernel)
return kernel
def random_bivariate_skew_Gaussian_center(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
noise_range=None,
strict=False):
"""Randomly generate bivariate skew Gaussian kernels at center.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation range (tuple): [-math.pi, math.pi]
noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
if strict:
sigma_max = np.max([sigma_x, sigma_y])
sigma_min = np.min([sigma_x, sigma_y])
sigma_x, sigma_y = sigma_max, sigma_min
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
sigma_max = np.max([sigma_x, sigma_y])
thres = 3 / sigma_max
D = [[np.random.uniform(-thres, thres),
np.random.uniform(-thres, thres)],
[np.random.uniform(-thres, thres),
np.random.uniform(-thres, thres)]]
kernel = bivariate_skew_Gaussian_center(kernel_size, sigma_x, sigma_y,
rotation, D)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(
noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
if strict:
return kernel, sigma_x, sigma_y, rotation, D
else:
return kernel
def random_bivariate_anisotropic_Gaussian(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
noise_range=None,
strict=False):
"""Randomly generate bivariate anisotropic Gaussian kernels.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation range (tuple): [-math.pi, math.pi]
noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
if strict:
sigma_max = np.max([sigma_x, sigma_y])
sigma_min = np.min([sigma_x, sigma_y])
sigma_x, sigma_y = sigma_max, sigma_min
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
kernel = bivariate_anisotropic_Gaussian(kernel_size, sigma_x, sigma_y,
rotation)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(
noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
if strict:
return kernel, sigma_x, sigma_y, rotation
else:
return kernel
def random_bivariate_isotropic_Gaussian(kernel_size,
sigma_range,
noise_range=None,
strict=False):
"""Randomly generate bivariate isotropic Gaussian kernels.
Args:
kernel_size (int):
sigma_range (tuple): [0.6, 5]
noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.'
sigma = np.random.uniform(sigma_range[0], sigma_range[1])
kernel = bivariate_isotropic_Gaussian(kernel_size, sigma)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(
noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
if strict:
return kernel, sigma
else:
return kernel
def random_bivariate_generalized_Gaussian(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
beta_range,
noise_range=None,
strict=False):
"""Randomly generate bivariate generalized Gaussian kernels.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation range (tuple): [-math.pi, math.pi]
beta_range (tuple): [0.5, 8]
noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
if strict:
sigma_max = np.max([sigma_x, sigma_y])
sigma_min = np.min([sigma_x, sigma_y])
sigma_x, sigma_y = sigma_max, sigma_min
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
if np.random.uniform() < 0.5:
beta = np.random.uniform(beta_range[0], 1)
else:
beta = np.random.uniform(1, beta_range[1])
kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y,
rotation, beta)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(
noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
if strict:
return kernel, sigma_x, sigma_y, rotation, beta
else:
return kernel
def random_bivariate_plateau_type1(kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
beta_range,
noise_range=None,
strict=False):
"""Randomly generate bivariate plateau type1 kernels.
Args:
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation range (tuple): [-math.pi/2, math.pi/2]
beta_range (tuple): [1, 4]
noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
if strict:
sigma_max = np.max([sigma_x, sigma_y])
sigma_min = np.min([sigma_x, sigma_y])
sigma_x, sigma_y = sigma_max, sigma_min
rotation = np.random.uniform(rotation_range[0], rotation_range[1])
if np.random.uniform() < 0.5:
beta = np.random.uniform(beta_range[0], 1)
else:
beta = np.random.uniform(1, beta_range[1])
kernel = bivariate_plateau_type1(kernel_size, sigma_x, sigma_y, rotation,
beta)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(
noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
if strict:
return kernel, sigma_x, sigma_y, rotation, beta
else:
return kernel
def random_bivariate_plateau_type1_iso(kernel_size,
sigma_range,
beta_range,
noise_range=None,
strict=False):
"""Randomly generate bivariate plateau type1 kernels (iso).
Args:
kernel_size (int):
sigma_range (tuple): [0.6, 5]
beta_range (tuple): [1, 4]
noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert sigma_range[0] < sigma_range[1], 'Wrong sigma_x_range.'
sigma = np.random.uniform(sigma_range[0], sigma_range[1])
beta = np.random.uniform(beta_range[0], beta_range[1])
kernel = bivariate_plateau_type1_iso(kernel_size, sigma, beta)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(
noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
if strict:
return kernel, sigma, beta
else:
return kernel
def random_mixed_kernels(kernel_list,
kernel_prob,
kernel_size=21,
sigma_x_range=[0.6, 5],
sigma_y_range=[0.6, 5],
rotation_range=[-math.pi, math.pi],
beta_range=[0.5, 8],
noise_range=None):
"""Randomly generate mixed kernels.
Args:
kernel_list (tuple): a list name of kenrel types,
support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', 'plateau_aniso']
kernel_prob (tuple): corresponding kernel probability for each kernel type
kernel_size (int):
sigma_x_range (tuple): [0.6, 5]
sigma_y_range (tuple): [0.6, 5]
rotation range (tuple): [-math.pi, math.pi]
beta_range (tuple): [0.5, 8]
noise_range(tuple, optional): multiplicative kernel noise, [0.75, 1.25]. Default: None
Returns:
kernel (ndarray):
"""
kernel_type = random.choices(kernel_list, kernel_prob)[0]
if kernel_type == 'iso':
kernel = random_bivariate_isotropic_Gaussian(
kernel_size, sigma_x_range, noise_range=noise_range)
elif kernel_type == 'aniso':
kernel = random_bivariate_anisotropic_Gaussian(
kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
noise_range=noise_range)
elif kernel_type == 'skew':
kernel = random_bivariate_skew_Gaussian_center(
kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
noise_range=noise_range)
elif kernel_type == 'generalized':
kernel = random_bivariate_generalized_Gaussian(
kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
beta_range,
noise_range=noise_range)
elif kernel_type == 'plateau_iso':
kernel = random_bivariate_plateau_type1_iso(
kernel_size, sigma_x_range, beta_range, noise_range=noise_range)
elif kernel_type == 'plateau_aniso':
kernel = random_bivariate_plateau_type1(
kernel_size,
sigma_x_range,
sigma_y_range,
rotation_range,
beta_range,
noise_range=noise_range)
# add multiplicative noise
if noise_range is not None:
assert noise_range[0] < noise_range[1], 'Wrong noise range.'
noise = np.random.uniform(
noise_range[0], noise_range[1], size=kernel.shape)
kernel = kernel * noise
kernel = kernel / np.sum(kernel)
return kernel
def show_one_kernel():
import matplotlib.pyplot as plt
kernel_size = 21
# bivariate skew Gaussian
D = [[0, 0], [0, 0]]
D = [[3 / 4, 0], [0, 0.5]]
kernel = bivariate_skew_Gaussian_center(kernel_size, 2, 4, -math.pi / 4, D)
# bivariate anisotropic Gaussian
kernel = bivariate_anisotropic_Gaussian(kernel_size, 2, 4, -math.pi / 4)
# bivariate anisotropic Gaussian
kernel = bivariate_isotropic_Gaussian(kernel_size, 1)
# bivariate generalized Gaussian
kernel = bivariate_generalized_Gaussian(
kernel_size, 2, 4, -math.pi / 4, beta=4)
delta_h, delta_w = mass_center_shift(kernel_size, kernel)
print(delta_h, delta_w)
fig, axs = plt.subplots(nrows=2, ncols=2)
# axs.set_axis_off()
ax = axs[0][0]
im = ax.matshow(kernel, cmap='jet', origin='upper')
fig.colorbar(im, ax=ax)
# image
ax = axs[0][1]
kernel_vis = kernel - np.min(kernel)
kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
ax.imshow(kernel_vis, interpolation='nearest')
_, xx, yy = mesh_grid(kernel_size)
# contour
ax = axs[1][0]
CS = ax.contour(xx, yy, kernel, origin='upper')
ax.clabel(CS, inline=1, fontsize=3)
# contourf
ax = axs[1][1]
kernel = kernel / np.max(kernel)
p = ax.contourf(
xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
fig.colorbar(p)
plt.show()
def show_plateau_kernel():
import matplotlib.pyplot as plt
kernel_size = 21
kernel = plateau_type1(kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
kernel_norm = bivariate_isotropic_Gaussian(kernel_size, 5)
kernel_gau = bivariate_generalized_Gaussian(
kernel_size, 2, 4, -math.pi / 8, 2, grid=None)
delta_h, delta_w = mass_center_shift(kernel_size, kernel)
print(delta_h, delta_w)
# kernel_slice = kernel[10, :]
# kernel_gau_slice = kernel_gau[10, :]
# kernel_norm_slice = kernel_norm[10, :]
# fig, ax = plt.subplots()
# t = list(range(1, 22))
# ax.plot(t, kernel_gau_slice)
# ax.plot(t, kernel_slice)
# ax.plot(t, kernel_norm_slice)
# t = np.arange(0, 10, 0.1)
# y = np.exp(-0.5 * t)
# y2 = np.reciprocal(1 + t)
# print(t.shape)
# print(y.shape)
# ax.plot(t, y)
# ax.plot(t, y2)
# plt.show()
fig, axs = plt.subplots(nrows=2, ncols=2)
# axs.set_axis_off()
ax = axs[0][0]
im = ax.matshow(kernel, cmap='jet', origin='upper')
fig.colorbar(im, ax=ax)
# image
ax = axs[0][1]
kernel_vis = kernel - np.min(kernel)
kernel_vis = kernel_vis / np.max(kernel_vis) * 255.
ax.imshow(kernel_vis, interpolation='nearest')
_, xx, yy = mesh_grid(kernel_size)
# contour
ax = axs[1][0]
CS = ax.contour(xx, yy, kernel, origin='upper')
ax.clabel(CS, inline=1, fontsize=3)
# contourf
ax = axs[1][1]
kernel = kernel / np.max(kernel)
p = ax.contourf(
xx, yy, kernel, origin='upper', levels=np.linspace(-0.05, 1.05, 10))
fig.colorbar(p)
plt.show()
================================================
FILE: basicsr/data/inpainting_dataset.py
================================================
import os
import random
from pathlib import Path
from PIL import Image
import cv2
import ffmpeg
import io
import av
import numpy as np
import torch
from torchvision.transforms.functional import normalize
from basicsr.data.degradations import (random_add_gaussian_noise,
random_mixed_kernels)
from basicsr.data.data_util import paths_from_folder, brush_stroke_mask, brush_stroke_mask_video, random_ff_mask
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, img2tensor, imfrombytes, scandir
from basicsr.utils.registry import DATASET_REGISTRY
from facelib.utils.face_restoration_helper import FaceAligner
from torch.utils import data as data
@DATASET_REGISTRY.register()
class InpaintingDataset(data.Dataset):
def __init__(self, opt):
super(InpaintingDataset, self).__init__()
self.opt = opt
self.gt_root = Path(opt['dataroot_gt'])
self.num_frame = opt['video_length'] # 5
self.scale = opt['scale'] # [1, 4]
self.need_align = opt.get('need_align', False) # False
self.normalize = opt.get('normalize', False) # True
self.keys = []
with open(opt['global_meta_info_file'], 'r') as fin:
for line in fin:
real_clip_path = '/'.join(line.split('/')[:-1])
clip_length = int(line.split('/')[-1])
self.keys.extend([f'{real_clip_path}/{clip_length:08d}/{0:08d}'])
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.is_lmdb = False
if self.io_backend_opt['type'] == 'lmdb':
self.is_lmdb = True
self.io_backend_opt['db_paths'] = [self.gt_root]
self.io_backend_opt['client_keys'] = ['gt']
# temporal augmentation configs
self.interval_list = opt['interval_list'] # [1]
self.random_reverse = opt['random_reverse']
interval_str = ','.join(str(x) for x in opt['interval_list']) # '1'
logger = get_root_logger()
logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
f'random reverse is {self.random_reverse}.')
# degradations
# blur
self.blur_kernel_size = opt['blur_kernel_size'] # 21
self.kernel_list = opt['kernel_list'] # ['iso', 'aniso']
self.kernel_prob = opt['kernel_prob'] # [0.5, 0.5]
self.blur_x_sigma = opt['blur_x_sigma'] # [0.2, 3]
self.blur_y_sigma = opt['blur_y_sigma'] # [0.2, 3]
# noise
self.noise_range = opt['noise_range'] # [0, 25]
# resize
self.resize_prob = opt['resize_prob'] # [0.25, 0.25, 0.5]
# crf
self.crf_range = opt['crf_range'] # [10, 30]
# codec
self.vcodec = opt['vcodec'] # ['libx264']
self.vcodec_prob = opt['vcodec_prob'] # [1]
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
f'x_sigma: [{", ".join(map(str, self.blur_x_sigma))}], '
f'y_sigma: [{", ".join(map(str, self.blur_y_sigma))}], ')
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
logger.info(f'CRF compression: [{", ".join(map(str, self.crf_range))}]')
logger.info(f'Codec: [{", ".join(map(str, self.vcodec))}]')
if self.need_align:
self.dataroot_meta_info = opt['dataroot_meta_info']
self.face_aligner = FaceAligner(
upscale_factor=1,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
use_parse=True)
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
key = self.keys[index]
real_clip_path = '/'.join(key.split('/')[:-2])
clip_length = int(key.split('/')[-2])
frame_idx = int(key.split('/')[-1])
clip_name = real_clip_path.split('/')[-1]
if os.path.exists(os.path.join(self.gt_root, "train", clip_name)):
paths = sorted(list(scandir(os.path.join(self.gt_root, "train", clip_name))))
elif os.path.exists(os.path.join(self.gt_root, "test", clip_name)):
paths = sorted(list(scandir(os.path.join(self.gt_root, "test", clip_name))))
else:
paths = sorted(list(scandir(os.path.join(self.gt_root, clip_name))))
# determine the neighboring frames
interval = random.choice(self.interval_list)
# exceed the length, re-select a new clip
while (clip_length - self.num_frame * interval) < 0:
interval = random.choice(self.interval_list)
# ensure not exceeding the borders
start_frame_idx = frame_idx - self.num_frame // 2 * interval
end_frame_idx = frame_idx + (self.num_frame + 1) // 2 * interval
while (start_frame_idx < 0) or (end_frame_idx > clip_length):
frame_idx = random.randint(self.num_frame // 2 * interval,
clip_length - self.num_frame // 2 * interval)
start_frame_idx = frame_idx - self.num_frame // 2 * interval
end_frame_idx = frame_idx + (self.num_frame + 1) // 2 * interval
neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))
# random reverse
if self.random_reverse and random.random() < 0.5:
neighbor_list.reverse()
assert len(neighbor_list) == self.num_frame, (
f'Wrong length of neighbor list: {len(neighbor_list)}')
# get the neighboring GT frames
img_gts = []
need_align = False
if self.need_align:
clip_info_path = os.path.join(self.dataroot_meta_info, f'{clip_name}.txt')
if os.path.exists(clip_info_path):
need_align = True
clip_info = []
with open(clip_info_path, 'r', encoding='utf-8') as fin:
for line in fin:
line = line.strip()
clip_info.append(line)
for neighbor in neighbor_list:
img_gt_path = os.path.join(self.gt_root, clip_name, paths[neighbor])
if not os.path.exists(img_gt_path):
img_gt_path = os.path.join(self.gt_root, "train", clip_name, paths[neighbor])
if not os.path.exists(img_gt_path):
img_gt_path = os.path.join(self.gt_root, "test", clip_name, paths[neighbor])
img_gt = np.asarray(Image.open(img_gt_path))[:, :, ::-1] / 255.0
img_gts.append(img_gt)
# augmentation - flip, rotate
img_gts = augment(img_gts, self.opt['use_flip'], self.opt['use_rot']) # False, False
# ------------- generate inpaint frames --------------#
img_lqs = img_gts
img_lqs = [Image.fromarray((_ * 255).astype('uint8')) for _ in img_lqs]
img_lqs = brush_stroke_mask_video(img_lqs)
img_lqs = [np.array(_) / 255. for _ in img_lqs]
# ------------ Align -------------#
if need_align:
align_lqs, align_gts = [], []
for frame_idx, (img_lq, img_gt) in enumerate(zip(img_lqs, img_gts)):
landmarks_str = clip_info[start_frame_idx + frame_idx].split(' ')
landmarks = np.array([float(x) for x in landmarks_str]).reshape(5, 2)
self.face_aligner.clean_all()
# align and warp each face
img_lq, img_gt = self.face_aligner.align_pair_face(img_lq, img_gt, landmarks)
align_lqs.append(img_lq)
align_gts.append(img_gt)
img_lqs, img_gts = align_lqs, align_gts
img_gts = img2tensor(img_gts)
img_lqs = img2tensor(img_lqs)
img_gts = torch.stack(img_gts, dim=0)
img_lqs = torch.stack(img_lqs, dim=0)
if self.normalize:
normalize(img_lqs, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
normalize(img_gts, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
return {'in': img_lqs, 'gt': img_gts, 'key': key}
def __len__(self):
return len(self.keys)
================================================
FILE: basicsr/data/paired_image_dataset.py
================================================
from torch.utils import data as data
from torchvision.transforms.functional import normalize
from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
"""Paired image dataset for image restoration.
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and
GT image pairs.
There are three modes:
1. 'lmdb': Use lmdb files.
If opt['io_backend'] == lmdb.
2. 'meta_info_file': Use meta information file to generate paths.
If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
3. 'folder': Scan folders to generate paths.
The rest.
Args:
opt (dict): Config for train datasets. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_lq (str): Data root path for lq.
meta_info_file (str): Path for meta information file.
io_backend (dict): IO backend type and other kwarg.
filename_tmpl (str): Template for each filename. Note that the
template excludes the file extension. Default: '{}'.
gt_size (int): Cropped patched size for gt patches.
use_flip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h
and w for implementation).
scale (bool): Scale, which will be added automatically.
phase (str): 'train' or 'val'.
"""
def __init__(self, opt):
super(PairedImageDataset, self).__init__()
self.opt = opt
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.mean = opt['mean'] if 'mean' in opt else None
self.std = opt['std'] if 'std' in opt else None
self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
if 'filename_tmpl' in opt:
self.filename_tmpl = opt['filename_tmpl']
else:
self.filename_tmpl = '{}'
if self.io_backend_opt['type'] == 'lmdb':
self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
self.io_backend_opt['client_keys'] = ['lq', 'gt']
self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
self.opt['meta_info_file'], self.filename_tmpl)
else:
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
scale = self.opt['scale']
# Load gt and lq images. Dimension order: HWC; channel order: BGR;
# image range: [0, 1], float32.
gt_path = self.paths[index]['gt_path']
img_bytes = self.file_client.get(gt_path, 'gt')
img_gt = imfrombytes(img_bytes, float32=True)
lq_path = self.paths[index]['lq_path']
img_bytes = self.file_client.get(lq_path, 'lq')
img_lq = imfrombytes(img_bytes, float32=True)
# augmentation for training
if self.opt['phase'] == 'train':
gt_size = self.opt['gt_size']
# random crop
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])
# TODO: color space transform
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
# normalize
if self.mean is not None or self.std is not None:
normalize(img_lq, self.mean, self.std, inplace=True)
normalize(img_gt, self.mean, self.std, inplace=True)
return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
def __len__(self):
return len(self.paths)
================================================
FILE: basicsr/data/prefetch_dataloader.py
================================================
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader
class PrefetchGenerator(threading.Thread):
"""A general prefetch generator.
Ref:
https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
Args:
generator: Python generator.
num_prefetch_queue (int): Number of prefetch queue.
"""
def __init__(self, generator, num_prefetch_queue):
threading.Thread.__init__(self)
self.queue = Queue.Queue(num_prefetch_queue)
self.generator = generator
self.daemon = True
self.start()
def run(self):
for item in self.generator:
self.queue.put(item)
self.queue.put(None)
def __next__(self):
next_item = self.queue.get()
if next_item is None:
raise StopIteration
return next_item
def __iter__(self):
return self
class PrefetchDataLoader(DataLoader):
"""Prefetch version of dataloader.
Ref:
https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
TODO:
Need to test on single gpu and ddp (multi-gpu). There is a known issue in
ddp.
Args:
num_prefetch_queue (int): Number of prefetch queue.
kwargs (dict): Other arguments for dataloader.
"""
def __init__(self, num_prefetch_queue, **kwargs):
self.num_prefetch_queue = num_prefetch_queue
super(PrefetchDataLoader, self).__init__(**kwargs)
def __iter__(self):
return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
class CPUPrefetcher():
"""CPU prefetcher.
Args:
loader: Dataloader.
"""
def __init__(self, loader):
self.ori_loader = loader
self.loader = iter(loader)
def next(self):
try:
return next(self.loader)
except StopIteration:
return None
def reset(self):
self.loader = iter(self.ori_loader)
class CUDAPrefetcher():
"""CUDA prefetcher.
Ref:
https://github.com/NVIDIA/apex/issues/304#
It may consums more GPU memory.
Args:
loader: Dataloader.
opt (dict): Options.
"""
def __init__(self, loader, opt):
self.ori_loader = loader
self.loader = iter(loader)
self.opt = opt
self.stream = torch.cuda.Stream()
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
self.preload()
def preload(self):
try:
self.batch = next(self.loader) # self.batch is a dict
except StopIteration:
self.batch = None
return None
# put tensors to gpu
with torch.cuda.stream(self.stream):
for k, v in self.batch.items():
if torch.is_tensor(v):
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
def next(self):
torch.cuda.current_stream().wait_stream(self.stream)
batch = self.batch
self.preload()
return batch
def reset(self):
self.loader = iter(self.ori_loader)
self.preload()
================================================
FILE: basicsr/data/transforms.py
================================================
import cv2
import random
def mod_crop(img, scale):
"""Mod crop images, used during testing.
Args:
img (ndarray): Input image.
scale (int): Scale factor.
Returns:
ndarray: Result image.
"""
img = img.copy()
if img.ndim in (2, 3):
h, w = img.shape[0], img.shape[1]
h_remainder, w_remainder = h % scale, w % scale
img = img[:h - h_remainder, :w - w_remainder, ...]
else:
raise ValueError(f'Wrong img ndim: {img.ndim}.')
return img
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
"""Paired random crop.
It crops lists of lq and gt images with corresponding locations.
Args:
img_gts (list[ndarray] | ndarray): GT images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
should have the same shape. If the input is an ndarray, it will
be transformed to a list containing itself.
gt_patch_size (int): GT patch size.
scale (int): Scale factor.
gt_path (str): Path to ground-truth.
Returns:
list[ndarray] | ndarray: GT images and LQ images. If returned results
only have one element, just return ndarray.
"""
if not isinstance(img_gts, list):
img_gts = [img_gts]
if not isinstance(img_lqs, list):
img_lqs = [img_lqs]
h_lq, w_lq, _ = img_lqs[0].shape
h_gt, w_gt, _ = img_gts[0].shape
lq_patch_size = gt_patch_size // scale
if h_gt != h_lq * scale or w_gt != w_lq * scale:
raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
f'multiplication of LQ ({h_lq}, {w_lq}).')
if h_lq < lq_patch_size or w_lq < lq_patch_size:
raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
f'({lq_patch_size}, {lq_patch_size}). '
f'Please remove {gt_path}.')
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
left = random.randint(0, w_lq - lq_patch_size)
# crop lq patch
img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
if len(img_gts) == 1:
img_gts = img_gts[0]
if len(img_lqs) == 1:
img_lqs = img_lqs[0]
return img_gts, img_lqs
def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
"""Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
We use vertical flip and transpose for rotation implementation.
All the images in the list use the same augmentation.
Args:
imgs (list[ndarray] | ndarray): Images to be augmented. If the input
is an ndarray, it will be transformed to a list.
hflip (bool): Horizontal flip. Default: True.
rotation (bool): Ratotation. Default: True.
flows (list[ndarray]: Flows to be augmented. If the input is an
ndarray, it will be transformed to a list.
Dimension is (h, w, 2). Default: None.
return_status (bool): Return the status of flip and rotation.
Default: False.
Returns:
list[ndarray] | ndarray: Augmented images and flows. If returned
results only have one element, just return ndarray.
"""
hflip = hflip and random.random() < 0.5
vflip = rotation and random.random() < 0.5
rot90 = rotation and random.random() < 0.5
def _augment(img):
if hflip: # horizontal
cv2.flip(img, 1, img)
if vflip: # vertical
cv2.flip(img, 0, img)
if rot90:
img = img.transpose(1, 0, 2)
return img
def _augment_flow(flow):
if hflip: # horizontal
cv2.flip(flow, 1, flow)
flow[:, :, 0] *= -1
if vflip: # vertical
cv2.flip(flow, 0, flow)
flow[:, :, 1] *= -1
if rot90:
flow = flow.transpose(1, 0, 2)
flow = flow[:, :, [1, 0]]
return flow
if not isinstance(imgs, list):
imgs = [imgs]
imgs = [_augment(img) for img in imgs]
if len(imgs) == 1:
imgs = imgs[0]
if flows is not None:
if not isinstance(flows, list):
flows = [flows]
flows = [_augment_flow(flow) for flow in flows]
if len(flows) == 1:
flows = flows[0]
return imgs, flows
else:
if return_status:
return imgs, (hflip, vflip, rot90)
else:
return imgs
def img_rotate(img, angle, center=None, scale=1.0):
"""Rotate image.
Args:
img (ndarray): Image to be rotated.
angle (float): Rotation angle in degrees. Positive values mean
counter-clockwise rotation.
center (tuple[int]): Rotation center. If the center is None,
initialize it as the center of the image. Default: None.
scale (float): Isotropic scale factor. Default: 1.0.
"""
(h, w) = img.shape[:2]
if center is None:
center = (w // 2, h // 2)
matrix = cv2.getRotationMatrix2D(center, angle, scale)
rotated_img = cv2.warpAffine(img, matrix, (w, h))
return rotated_img
================================================
FILE: basicsr/data/vfhq_dataset.py
================================================
import os
import random
from pathlib import Path
from PIL import Image
import cv2
import ffmpeg
import io
import av
import numpy as np
import torch
from torchvision.transforms.functional import normalize
from basicsr.data.degradations import (random_add_gaussian_noise,
random_mixed_kernels)
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, img2tensor, imfrombytes, scandir
from basicsr.utils.registry import DATASET_REGISTRY
from facelib.utils.face_restoration_helper import FaceAligner
from torch.utils import data as data
@DATASET_REGISTRY.register()
class VFHQRealDegradationDatasetNew(data.Dataset):
"""Support for blind setting adopted in paper. We excludes the random scale compared to GFPGAN.
This dataset is adopted in BasicVSR.
The degradation order is blur+downsample+noise
Directly read image by cv2. Generate LR images online.
NOTE: The specific degradation order is blur-noise-downsample-crf-upsample
The keys are generated from a meta info txt file.
Key format: subfolder-name/clip-length/frame-name
Key examples: "id00020#t0bbIRgKKzM#00381.txt#000.mp4/00000152/00000000"
GT (gt): Ground-Truth;
LQ (lq): Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames.
Args:
opt (dict): Config for train dataset. It contains the following keys:
dataroot_gt (str): Data root path for gt.
dataroot_clip_meta_info (srt): Data root path for meta info of each gt clip.
global_meta_info_file (str): Path for global meta information file.
io_backend (dict): IO backend type and other kwarg.
num_frame (int): Window size for input frames.
interval_list (list): Interval list for temporal augmentation.
random_reverse (bool): Random reverse input frames.
use_flip (bool): Use horizontal flips.
use_rot (bool): Use rotation (use vertical flip and transposing h
and w for implementation).
"""
def __init__(self, opt):
super(VFHQRealDegradationDatasetNew, self).__init__()
self.opt = opt
self.gt_root = Path(opt['dataroot_gt'])
self.num_frame = opt['video_length'] # 5
self.scale = opt['scale'] # [1, 4]
self.need_align = opt.get('need_align', False) # False
self.normalize = opt.get('normalize', False) # True
self.keys = []
with open(opt['global_meta_info_file'], 'r') as fin:
for line in fin:
real_clip_path = '/'.join(line.split('/')[:-1])
clip_length = int(line.split('/')[-1])
self.keys.extend([f'{real_clip_path}/{clip_length:08d}/{0:08d}'])
# file client (io backend)
self.file_client = None
self.io_backend_opt = opt['io_backend']
self.is_lmdb = False
if self.io_backend_opt['type'] == 'lmdb':
self.is_lmdb = True
self.io_backend_opt['db_paths'] = [self.gt_root]
self.io_backend_opt['client_keys'] = ['gt']
# temporal augmentation configs
self.interval_list = opt['interval_list'] # [1]
self.random_reverse = opt['random_reverse']
interval_str = ','.join(str(x) for x in opt['interval_list']) # '1'
logger = get_root_logger()
logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
f'random reverse is {self.random_reverse}.')
# degradations
# blur
self.blur_kernel_size = opt['blur_kernel_size'] # 21
self.kernel_list = opt['kernel_list'] # ['iso', 'aniso']
self.kernel_prob = opt['kernel_prob'] # [0.5, 0.5]
self.blur_x_sigma = opt['blur_x_sigma'] # [0.2, 3]
self.blur_y_sigma = opt['blur_y_sigma'] # [0.2, 3]
# noise
self.noise_range = opt['noise_range'] # [0, 25]
# resize
self.resize_prob = opt['resize_prob'] # [0.25, 0.25, 0.5]
# crf
self.crf_range = opt['crf_range'] # [10, 30]
# codec
self.vcodec = opt['vcodec'] # ['libx264']
self.vcodec_prob = opt['vcodec_prob'] # [1]
logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, '
f'x_sigma: [{", ".join(map(str, self.blur_x_sigma))}], '
f'y_sigma: [{", ".join(map(str, self.blur_y_sigma))}], ')
logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]')
logger.info(f'CRF compression: [{", ".join(map(str, self.crf_range))}]')
logger.info(f'Codec: [{", ".join(map(str, self.vcodec))}]')
if self.need_align:
self.dataroot_meta_info = opt['dataroot_meta_info']
self.face_aligner = FaceAligner(
upscale_factor=1,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
use_parse=True)
def __getitem__(self, index):
if self.file_client is None:
self.file_client = FileClient(
self.io_backend_opt.pop('type'), **self.io_backend_opt)
key = self.keys[index]
real_clip_path = '/'.join(key.split('/')[:-2])
clip_length = int(key.split('/')[-2])
frame_idx = int(key.split('/')[-1])
clip_name = real_clip_path.split('/')[-1]
if os.path.exists(os.path.join(self.gt_root, "train", clip_name)):
paths = sorted(list(scandir(os.path.join(self.gt_root, "train", clip_name))))
elif os.path.exists(os.path.join(self.gt_root, "test", clip_name)):
paths = sorted(list(scandir(os.path.join(self.gt_root, "test", clip_name))))
else:
paths = sorted(list(scandir(os.path.join(self.gt_root, clip_name))))
# determine the neighboring frames
interval = random.choice(self.interval_list)
# exceed the length, re-select a new clip
while (clip_length - self.num_frame * interval) < 0:
interval = random.choice(self.interval_list)
# ensure not exceeding the borders
start_frame_idx = frame_idx - self.num_frame // 2 * interval
end_frame_idx = frame_idx + (self.num_frame + 1) // 2 * interval
while (start_frame_idx < 0) or (end_frame_idx > clip_length):
frame_idx = random.randint(self.num_frame // 2 * interval,
clip_length - self.num_frame // 2 * interval)
start_frame_idx = frame_idx - self.num_frame // 2 * interval
end_frame_idx = frame_idx + (self.num_frame + 1) // 2 * interval
neighbor_list = list(range(start_frame_idx, end_frame_idx, interval))
# random reverse
if self.random_reverse and random.random() < 0.5:
neighbor_list.reverse()
assert len(neighbor_list) == self.num_frame, (
f'Wrong length of neighbor list: {len(neighbor_list)}')
# get the neighboring GT frames
img_gts = []
need_align = False
if self.need_align:
clip_info_path = os.path.join(self.dataroot_meta_info, f'{clip_name}.txt')
if os.path.exists(clip_info_path):
need_align = True
clip_info = []
with open(clip_info_path, 'r', encoding='utf-8') as fin:
for line in fin:
line = line.strip()
clip_info.append(line)
for neighbor in neighbor_list:
img_gt_path = os.path.join(self.gt_root, clip_name, paths[neighbor])
if not os.path.exists(img_gt_path):
img_gt_path = os.path.join(self.gt_root, "train", clip_name, paths[neighbor])
if not os.path.exists(img_gt_path):
img_gt_path = os.path.join(self.gt_root, "test", clip_name, paths[neighbor])
img_gt = np.asarray(Image.open(img_gt_path))[:, :, ::-1] / 255.0
img_gts.append(img_gt)
# augmentation - flip, rotate
img_gts = augment(img_gts, self.opt['use_flip'], self.opt['use_rot']) # False, False
# ------------- generate LQ frames --------------#
# add blur
kernel = random_mixed_kernels(self.kernel_list,
self.kernel_prob, # [0.7, 0.3]
self.blur_kernel_size, # 21
self.blur_x_sigma, # [0.1, 10]
self.blur_y_sigma) # [0.1, 10]
img_lqs = [cv2.filter2D(v, -1, kernel) for v in img_gts]
# downsample
ori_height, ori_width = img_gts[0].shape[0:2]
resize_type = random.choices([cv2.INTER_AREA,
cv2.INTER_LINEAR,
cv2.INTER_CUBIC], self.resize_prob)[0]
# ensure the resized_height and resized_width are even numbers
# scale = np.random.uniform(self.scale)
resized_height = int(ori_height // self.scale) // 2 * 2
resized_width = int(ori_width // self.scale) // 2 * 2
img_lqs = [cv2.resize(v, (resized_width, resized_height),
interpolation=resize_type) for v in img_lqs]
# add noise
img_lqs = [random_add_gaussian_noise(v,
self.noise_range, # [0, 10]
gray_prob=0.5,
clip=True,
rounds=False) for v in img_lqs] # noise_range: [0, 25]
# ffmpeg
crf = np.random.randint(self.crf_range[0], self.crf_range[1]) # [18, 25]
codec = random.choices(self.vcodec, self.vcodec_prob)[0] # 'libx264'
buf = io.BytesIO()
with av.open(buf, 'w', 'mp4') as container:
stream = container.add_stream(codec, rate=1)
stream.height = resized_height
stream.width = resized_width
stream.pix_fmt = 'yuv420p'
stream.options = {'crf': str(crf)}
for img_lq in img_lqs:
img_lq = np.clip(img_lq * 255, 0, 255).astype(np.uint8)
frame = av.VideoFrame.from_ndarray(img_lq, format='rgb24')
frame.pict_type = av.video.frame.PictureType.NONE
for packet in stream.encode(frame):
container.mux(packet)
# Flush stream
for packet in stream.encode():
container.mux(packet)
img_lqs = []
with av.open(buf, 'r', 'mp4') as container:
if container.streams.video:
for frame in container.decode(**{'video': 0}):
frame = frame.to_rgb().to_ndarray()
frame = cv2.resize(frame, (ori_width, ori_height), interpolation=resize_type) # upsample
img_lqs.append(frame / 255.)
assert len(img_lqs) == len(img_gts), 'Wrong length'
# ------------ Align -------------#
if need_align:
align_lqs, align_gts = [], []
for frame_idx, (img_lq, img_gt) in enumerate(zip(img_lqs, img_gts)):
landmarks_str = clip_info[start_frame_idx + frame_idx].split(' ')
landmarks = np.array([float(x) for x in landmarks_str]).reshape(5, 2)
self.face_aligner.clean_all()
# align and warp each face
img_lq, img_gt = self.face_aligner.align_pair_face(
img_lq, img_gt, landmarks)
align_lqs.append(img_lq)
align_gts.append(img_gt)
img_lqs, img_gts = align_lqs, align_gts
img_gts = img2tensor(img_gts)
img_lqs = img2tensor(img_lqs)
img_gts = torch.stack(img_gts, dim=0)
img_lqs = torch.stack(img_lqs, dim=0)
if self.normalize:
normalize(img_lqs, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
normalize(img_gts, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
return {'in': img_lqs, 'gt': img_gts, 'key': key}
def __len__(self):
return len(self.keys)
================================================
FILE: basicsr/losses/__init__.py
================================================
from copy import deepcopy
from basicsr.utils import get_root_logger
from basicsr.utils.registry import LOSS_REGISTRY
from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
gradient_penalty_loss, r1_penalty)
__all__ = [
'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
'r1_penalty', 'g_path_regularize'
]
def build_loss(opt):
"""Build loss from options.
Args:
opt (dict): Configuration. It must constain:
type (str): Model type.
"""
opt = deepcopy(opt)
loss_type = opt.pop('type')
loss = LOSS_REGISTRY.get(loss_type)(**opt)
logger = get_root_logger()
logger.info(f'Loss [{loss.__class__.__name__}] is created.')
return loss
================================================
FILE: basicsr/losses/loss_util.py
================================================
import functools
from torch.nn import functional as F
def reduce_loss(loss, reduction):
"""Reduce loss as specified.
Args:
loss (Tensor): Elementwise loss tensor.
reduction (str): Options are 'none', 'mean' and 'sum'.
Returns:
Tensor: Reduced loss tensor.
"""
reduction_enum = F._Reduction.get_enum(reduction)
# none: 0, elementwise_mean:1, sum: 2
if reduction_enum == 0:
return loss
elif reduction_enum == 1:
return loss.mean()
else:
return loss.sum()
def weight_reduce_loss(loss, weight=None, reduction='mean'):
"""Apply element-wise weight and reduce loss.
Args:
loss (Tensor): Element-wise loss.
weight (Tensor): Element-wise weights. Default: None.
reduction (str): Same as built-in losses of PyTorch. Options are
'none', 'mean' and 'sum'. Default: 'mean'.
Returns:
Tensor: Loss values.
"""
# if weight is specified, apply element-wise weight
if weight is not None:
assert weight.dim() == loss.dim()
assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
loss = loss * weight
# if weight is not specified or reduction is sum, just reduce the loss
if weight is None or reduction == 'sum':
loss = reduce_loss(loss, reduction)
# if reduction is mean, then compute mean over weight region
elif reduction == 'mean':
if weight.size(1) > 1:
weight = weight.sum()
else:
weight = weight.sum() * loss.size(1)
loss = loss.sum() / weight
return loss
def weighted_loss(loss_func):
"""Create a weighted version of a given loss function.
To use this decorator, the loss function must have the signature like
`loss_func(pred, target, **kwargs)`. The function only needs to compute
element-wise loss without any reduction. This decorator will add weight
and reduction arguments to the function. The decorated function will have
the signature like `loss_func(pred, target, weight=None, reduction='mean',
**kwargs)`.
:Example:
>>> import torch
>>> @weighted_loss
>>> def l1_loss(pred, target):
>>> return (pred - target).abs()
>>> pred = torch.Tensor([0, 2, 3])
>>> target = torch.Tensor([1, 1, 1])
>>> weight = torch.Tensor([1, 0, 1])
>>> l1_loss(pred, target)
tensor(1.3333)
>>> l1_loss(pred, target, weight)
tensor(1.5000)
>>> l1_loss(pred, target, reduction='none')
tensor([1., 1., 2.])
>>> l1_loss(pred, target, weight, reduction='sum')
tensor(3.)
"""
@functools.wraps(loss_func)
def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
# get element-wise loss
loss = loss_func(pred, target, **kwargs)
loss = weight_reduce_loss(loss, weight, reduction)
return loss
return wrapper
================================================
FILE: basicsr/losses/losses.py
================================================
import math
import lpips
import torch
from torch import autograd as autograd
from torch import nn as nn
from torch.nn import functional as F
from basicsr.archs.vgg_arch import VGGFeatureExtractor
from basicsr.utils.registry import LOSS_REGISTRY
from .loss_util import weighted_loss
# from basicsr.losses.loss_util import weighted_loss
_reduction_modes = ['none', 'mean', 'sum']
@weighted_loss
def l1_loss(pred, target):
return F.l1_loss(pred, target, reduction='none')
@weighted_loss
def mse_loss(pred, target):
return F.mse_loss(pred, target, reduction='none')
@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
return torch.sqrt((pred - target)**2 + eps)
@LOSS_REGISTRY.register()
class L1Loss(nn.Module):
"""L1 (mean absolute error, MAE) loss.
Args:
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
"""
def __init__(self, loss_weight=1.0, reduction='mean'):
super(L1Loss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.reduction = reduction
def forward(self, pred, target, weight=None, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
weights. Default: None.
"""
return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
@LOSS_REGISTRY.register()
class MSELoss(nn.Module):
"""MSE (L2) loss.
Args:
loss_weight (float): Loss weight for MSE loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
"""
def __init__(self, loss_weight=1.0, reduction='mean'):
super(MSELoss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.reduction = reduction
def forward(self, pred, target, weight=None, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
weights. Default: None.
"""
return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
"""Charbonnier loss (one variant of Robust L1Loss, a differentiable
variant of L1Loss).
Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
Super-Resolution".
Args:
loss_weight (float): Loss weight for L1 loss. Default: 1.0.
reduction (str): Specifies the reduction to apply to the output.
Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
eps (float): A value used to control the curvature near zero.
Default: 1e-12.
"""
def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
super(CharbonnierLoss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.reduction = reduction
self.eps = eps
def forward(self, pred, target, weight=None, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise
weights. Default: None.
"""
return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
@LOSS_REGISTRY.register()
class WeightedTVLoss(L1Loss):
"""Weighted TV loss.
Args:
loss_weight (float): Loss weight. Default: 1.0.
"""
def __init__(self, loss_weight=1.0):
super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
def forward(self, pred, weight=None):
y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=weight[:, :, :-1, :])
x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=weight[:, :, :, :-1])
loss = x_diff + y_diff
return loss
@LOSS_REGISTRY.register()
class PerceptualLoss(nn.Module):
"""Perceptual loss with commonly used style loss.
Args:
layer_weights (dict): The weight for each layer of vgg feature.
Here is an example: {'conv5_4': 1.}, which means the conv5_4
feature layer (before relu5_4) will be extracted with weight
1.0 in calculting losses.
vgg_type (str): The type of vgg network used as feature extractor.
Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image in vgg.
Default: True.
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
Default: False.
perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
loss will be calculated and the loss will multiplied by the
weight. Default: 1.0.
style_weight (float): If `style_weight > 0`, the style loss will be
calculated and the loss will multiplied by the weight.
Default: 0.
criterion (str): Criterion used for perceptual loss. Default: 'l1'.
"""
def __init__(self,
layer_weights,
vgg_type='vgg19',
use_input_norm=True,
range_norm=False,
perceptual_weight=1.0,
style_weight=0.,
criterion='l1'):
super(PerceptualLoss, self).__init__()
self.perceptual_weight = perceptual_weight
self.style_weight = style_weight
self.layer_weights = layer_weights
self.vgg = VGGFeatureExtractor(
layer_name_list=list(layer_weights.keys()),
vgg_type=vgg_type,
use_input_norm=use_input_norm,
range_norm=range_norm)
self.criterion_type = criterion
if self.criterion_type == 'l1':
self.criterion = torch.nn.L1Loss()
elif self.criterion_type == 'l2':
self.criterion = torch.nn.L2loss()
elif self.criterion_type == 'mse':
self.criterion = torch.nn.MSELoss(reduction='mean')
elif self.criterion_type == 'fro':
self.criterion = None
else:
raise NotImplementedError(f'{criterion} criterion has not been supported.')
def forward(self, x, gt):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
# extract vgg features
x_features = self.vgg(x)
gt_features = self.vgg(gt.detach())
# calculate perceptual loss
if self.perceptual_weight > 0:
percep_loss = 0
for k in x_features.keys():
if self.criterion_type == 'fro':
percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
else:
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
percep_loss *= self.perceptual_weight
else:
percep_loss = None
# calculate style loss
if self.style_weight > 0:
style_loss = 0
for k in x_features.keys():
if self.criterion_type == 'fro':
style_loss += torch.norm(
self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
else:
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
gt_features[k])) * self.layer_weights[k]
style_loss *= self.style_weight
else:
style_loss = None
return percep_loss, style_loss
def _gram_mat(self, x):
"""Calculate Gram matrix.
Args:
x (torch.Tensor): Tensor with shape of (n, c, h, w).
Returns:
torch.Tensor: Gram matrix.
"""
n, c, h, w = x.size()
features = x.view(n, c, w * h)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t) / (c * h * w)
return gram
@LOSS_REGISTRY.register()
class LPIPSLoss(nn.Module):
def __init__(self,
loss_weight=1.0,
use_input_norm=True,
range_norm=False,):
super(LPIPSLoss, self).__init__()
self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
self.loss_weight = loss_weight
self.use_input_norm = use_input_norm
self.range_norm = range_norm
if self.use_input_norm:
# the mean is for image with range [0, 1]
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
# the std is for image with range [0, 1]
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def forward(self, pred, target):
if self.range_norm:
pred = (pred + 1) / 2
target = (target + 1) / 2
if self.use_input_norm:
pred = (pred - self.mean) / self.std
target = (target - self.mean) / self.std
lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
return self.loss_weight * lpips_loss.mean()
@LOSS_REGISTRY.register()
class GANLoss(nn.Module):
"""Define GAN loss.
Args:
gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
real_label_val (float): The value for real label. Default: 1.0.
fake_label_val (float): The value for fake label. Default: 0.0.
loss_weight (float): Loss weight. Default: 1.0.
Note that loss_weight is only for generators; and it is always 1.0
for discriminators.
"""
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
super(GANLoss, self).__init__()
self.gan_type = gan_type
self.loss_weight = loss_weight
self.real_label_val = real_label_val
self.fake_label_val = fake_label_val
if self.gan_type == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
elif self.gan_type == 'lsgan':
self.loss = nn.MSELoss()
elif self.gan_type == 'wgan':
self.loss = self._wgan_loss
elif self.gan_type == 'wgan_softplus':
self.loss = self._wgan_softplus_loss
elif self.gan_type == 'hinge':
self.loss = nn.ReLU()
else:
raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
def _wgan_loss(self, input, target):
"""wgan loss.
Args:
input (Tensor): Input tensor.
target (bool): Target label.
Returns:
Tensor: wgan loss.
"""
return -input.mean() if target else input.mean()
def _wgan_softplus_loss(self, input, target):
"""wgan loss with soft plus. softplus is a smooth approximation to the
ReLU function.
In StyleGAN2, it is called:
Logistic loss for discriminator;
Non-saturating loss for generator.
Args:
input (Tensor): Input tensor.
target (bool): Target label.
Returns:
Tensor: wgan loss.
"""
return F.softplus(-input).mean() if target else F.softplus(input).mean()
def get_target_label(self, input, target_is_real):
"""Get target label.
Args:
input (Tensor): Input tensor.
target_is_real (bool): Whether the target is real or fake.
Returns:
(bool | Tensor): Target tensor. Return bool for wgan, otherwise,
return Tensor.
"""
if self.gan_type in ['wgan', 'wgan_softplus']:
return target_is_real
target_val = (self.real_label_val if target_is_real else self.fake_label_val)
return input.new_ones(input.size()) * target_val
def forward(self, input, target_is_real, is_disc=False):
"""
Args:
input (Tensor): The input for the loss module, i.e., the network
prediction.
target_is_real (bool): Whether the targe is real or fake.
is_disc (bool): Whether the loss for discriminators or not.
Default: False.
Returns:
Tensor: GAN loss value.
"""
if self.gan_type == 'hinge':
if is_disc: # for discriminators in hinge-gan
input = -input if target_is_real else input
loss = self.loss(1 + input).mean()
else: # for generators in hinge-gan
loss = -input.mean()
else: # other gan types
target_label = self.get_target_label(input, target_is_real)
loss = self.loss(input, target_label)
# loss_weight is always 1.0 for discriminators
return loss if is_disc else loss * self.loss_weight
def r1_penalty(real_pred, real_img):
"""R1 regularization for discriminator. The core idea is to
penalize the gradient on real data alone: when the
generator distribution produces the true data distribution
and the discriminator is equal to 0 on the data manifold, the
gradient penalty ensures that the discriminator cannot create
a non-zero gradient orthogonal to the data manifold without
suffering a loss in the GAN game.
Ref:
Eq. 9 in Which training methods for GANs do actually converge.
"""
grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
return grad_penalty
def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
path_penalty = (path_lengths - path_mean).pow(2).mean()
return path_penalty, path_lengths.detach().mean(), path_mean.detach()
def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
"""Calculate gradient penalty for wgan-gp.
Args:
discriminator (nn.Module): Network for the discriminator.
real_data (Tensor): Real input data.
fake_data (Tensor): Fake input data.
weight (Tensor): Weight tensor. Default: None.
Returns:
Tensor: A tensor for gradient penalty.
"""
batch_size = real_data.size(0)
alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
# interpolate between real_data and fake_data
interpolates = alpha * real_data + (1. - alpha) * fake_data
interpolates = autograd.Variable(interpolates, requires_grad=True)
disc_interpolates = discriminator(interpolates)
gradients = autograd.grad(
outputs=disc_interpolates,
inputs=interpolates,
grad_outputs=torch.ones_like(disc_interpolates),
create_graph=True,
retain_graph=True,
only_inputs=True)[0]
if weight is not None:
gradients = gradients * weight
gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
if weight is not None:
gradients_penalty /= torch.mean(weight)
return gradients_penalty
@LOSS_REGISTRY.register()
class DirichletKLLoss(nn.Module):
"""Dir distribution KL-loss.
Args:
loss_weight (float): Loss weight. Default: 1.0.
"""
def __init__(self, loss_weight=1.0, kl_coef=1.1):
super(DirichletKLLoss, self).__init__()
self.loss_weight = loss_weight
self.kl_coef = kl_coef
def forward(self, alpha):
beta = self.kl_coef * torch.ones_like(alpha)
l1 = torch.lgamma(alpha.sum(dim=-1, keepdim=True))
l2 = torch.lgamma(alpha).sum(dim=-1, keepdim=True)
l3 = (alpha - beta) * (torch.digamma(alpha) - torch.digamma(alpha.sum(dim=-1, keepdim=True)))
loss = l1 - l2 + l3.sum(dim=-1,keepdim=True)
loss = loss.mean()
if self.loss_weight > 0.1:
import torch.distributions as dist
dirichlet_dist = dist.Dirichlet(alpha)
parameters = dirichlet_dist.rsample()
maxium = torch.max(parameters, dim=-1)[0]
# print(maxium.mean())
return self.loss_weight*loss
if __name__ == '__main__':
LpipsLoss = LPIPSLoss()
================================================
FILE: basicsr/metrics/__init__.py
================================================
from copy import deepcopy
from basicsr.utils.registry import METRIC_REGISTRY
from .psnr_ssim import calculate_psnr, calculate_ssim
__all__ = ['calculate_psnr', 'calculate_ssim']
def calculate_metric(data, opt):
"""Calculate metric from data and options.
Args:
opt (dict): Configuration. It must constain:
type (str): Model type.
"""
opt = deepcopy(opt)
metric_type = opt.pop('type')
metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
return metric
================================================
FILE: basicsr/metrics/metric_util.py
================================================
import numpy as np
from basicsr.utils.matlab_functions import bgr2ycbcr
def reorder_image(img, input_order='HWC'):
"""Reorder images to 'HWC' order.
If the input_order is (h, w), return (h, w, 1);
If the input_order is (c, h, w), return (h, w, c);
If the input_order is (h, w, c), return as it is.
Args:
img (ndarray): Input image.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
If the input image shape is (h, w), input_order will not have
effects. Default: 'HWC'.
Returns:
ndarray: reordered image.
"""
if input_order not in ['HWC', 'CHW']:
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
if len(img.shape) == 2:
img = img[..., None]
if input_order == 'CHW':
img = img.transpose(1, 2, 0)
return img
def to_y_channel(img):
"""Change to Y channel of YCbCr.
Args:
img (ndarray): Images with range [0, 255].
Returns:
(ndarray): Images with range [0, 255] (float type) without round.
"""
img = img.astype(np.float32) / 255.
if img.ndim == 3 and img.shape[2] == 3:
img = bgr2ycbcr(img, y_only=True)
img = img[..., None]
return img * 255.
================================================
FILE: basicsr/metrics/psnr_ssim.py
================================================
import cv2
import numpy as np
from basicsr.metrics.metric_util import reorder_image, to_y_channel
from basicsr.utils.registry import METRIC_REGISTRY
@METRIC_REGISTRY.register()
def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
"""Calculate PSNR (Peak Signal-to-Noise Ratio).
Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
Args:
img1 (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These
pixels are not involved in the PSNR calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
Default: 'HWC'.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: psnr result.
"""
assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
img1 = reorder_image(img1, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
if crop_border != 0:
img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
if test_y_channel:
img1 = to_y_channel(img1)
img2 = to_y_channel(img2)
mse = np.mean((img1 - img2)**2)
if mse == 0:
return float('inf')
return 20. * np.log10(255. / np.sqrt(mse))
def _ssim(img1, img2):
"""Calculate SSIM (structural similarity) for one channel images.
It is called by func:`calculate_ssim`.
Args:
img1 (ndarray): Images with range [0, 255] with order 'HWC'.
img2 (ndarray): Images with range [0, 255] with order 'HWC'.
Returns:
float: ssim result.
"""
C1 = (0.01 * 255)**2
C2 = (0.03 * 255)**2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
@METRIC_REGISTRY.register()
def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
"""Calculate SSIM (structural similarity).
Ref:
Image quality assessment: From error visibility to structural similarity
The results are the same as that of the official released MATLAB code in
https://ece.uwaterloo.ca/~z70wang/research/ssim/.
For three-channel images, SSIM is calculated for each channel and then
averaged.
Args:
img1 (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These
pixels are not involved in the SSIM calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
Default: 'HWC'.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: ssim result.
"""
assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
img1 = reorder_image(img1, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
if crop_border != 0:
img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
if test_y_channel:
img1 = to_y_channel(img1)
img2 = to_y_channel(img2)
ssims = []
for i in range(img1.shape[2]):
ssims.append(_ssim(img1[..., i], img2[..., i]))
return np.array(ssims).mean()
================================================
FILE: basicsr/models/__init__.py
================================================
import importlib
from copy import deepcopy
from os import path as osp
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import MODEL_REGISTRY
__all__ = ['build_model']
# automatically scan and import model modules for registry
# scan all the files under the 'models' folder and collect files ending with
# '_model.py'
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
def build_model(opt):
"""Build model from options.
Args:
opt (dict): Configuration. It must constain:
model_type (str): Model type.
"""
opt = deepcopy(opt)
model = MODEL_REGISTRY.get(opt['model_type'])(opt)
logger = get_root_logger()
logger.info(f'Model [{model.__class__.__name__}] is created.')
return model
================================================
FILE: basicsr/models/base_model.py
================================================
import logging
import os
import torch
from collections import OrderedDict
from copy import deepcopy
from torch.nn.parallel import DataParallel, DistributedDataParallel
from basicsr.models import lr_scheduler as lr_scheduler
from basicsr.utils.dist_util import master_only
logger = logging.getLogger('basicsr')
class BaseModel():
"""Base model."""
def __init__(self, opt):
self.opt = opt
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
self.is_train = opt['is_train']
self.schedulers = []
self.optimizers = []
def feed_data(self, data):
pass
def optimize_parameters(self):
pass
def get_current_visuals(self):
pass
def save(self, epoch, current_iter):
"""Save networks and training state."""
pass
def validation(self, dataloader, current_iter, tb_logger, save_img=False):
"""Validation function.
Args:
dataloader (torch.utils.data.DataLoader): Validation dataloader.
current_iter (int): Current iteration.
tb_logger (tensorboard logger): Tensorboard logger.
save_img (bool): Whether to save images. Default: False.
"""
if self.opt['dist']:
self.dist_validation(dataloader, current_iter, tb_logger, save_img)
else:
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
def model_ema(self, decay=0.999):
net_g = self.get_bare_model(self.net_g)
net_g_params = dict(net_g.named_parameters())
net_g_ema_params = dict(self.net_g_ema.named_parameters())
for k in net_g_ema_params.keys():
net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
def get_current_log(self):
return self.log_dict
def model_to_device(self, net):
"""Model to device. It also warps models with DistributedDataParallel
or DataParallel.
Args:
net (nn.Module)
"""
net = net.to(self.device)
if self.opt['dist']:
find_unused_parameters = self.opt.get('find_unused_parameters', False)
net = DistributedDataParallel(
net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
elif self.opt['num_gpu'] > 1:
net = DataParallel(net)
return net
def get_optimizer(self, optim_type, params, lr, **kwargs):
if optim_type == 'Adam':
optimizer = torch.optim.Adam(params, lr, **kwargs)
else:
raise NotImplementedError(f'optimizer {optim_type} is not supperted yet.')
return optimizer
def setup_schedulers(self):
"""Set up schedulers."""
train_opt = self.opt['train']
scheduler_type = train_opt['scheduler'].pop('type')
if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
for optimizer in self.optimizers:
self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
elif scheduler_type == 'CosineAnnealingRestartLR':
for optimizer in self.optimizers:
self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
else:
raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
def get_bare_model(self, net):
"""Get bare model, especially under wrapping with
DistributedDataParallel or DataParallel.
"""
if isinstance(net, (DataParallel, DistributedDataParallel)):
net = net.module
return net
@master_only
def print_network(self, net):
"""Print the str and parameter number of a network.
Args:
net (nn.Module)
"""
if isinstance(net, (DataParallel, DistributedDataParallel)):
net_cls_str = (f'{net.__class__.__name__} - ' f'{net.module.__class__.__name__}')
else:
net_cls_str = f'{net.__class__.__name__}'
net = self.get_bare_model(net)
net_str = str(net)
net_params = sum(map(lambda x: x.numel(), net.parameters()))
logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
logger.info(net_str)
def _set_lr(self, lr_groups_l):
"""Set learning rate for warmup.
Args:
lr_groups_l (list): List for lr_groups, each for an optimizer.
"""
for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
for param_group, lr in zip(optimizer.param_groups, lr_groups):
param_group['lr'] = lr
def _get_init_lr(self):
"""Get the initial lr, which is set by the scheduler.
"""
init_lr_groups_l = []
for optimizer in self.optimizers:
init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
return init_lr_groups_l
def update_learning_rate(self, current_iter, warmup_iter=-1):
"""Update learning rate.
Args:
current_iter (int): Current iteration.
warmup_iter (int): Warmup iter numbers. -1 for no warmup.
Default: -1.
"""
if current_iter > 1:
for scheduler in self.schedulers:
scheduler.step()
# set up warm-up learning rate
if current_iter < warmup_iter:
# get initial lr for each group
init_lr_g_l = self._get_init_lr()
# modify warming-up learning rates
# currently only support linearly warm up
warm_up_lr_l = []
for init_lr_g in init_lr_g_l:
warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
# set learning rate
self._set_lr(warm_up_lr_l)
def get_current_learning_rate(self):
return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
@master_only
def save_network(self, net, net_label, current_iter, param_key='params'):
"""Save networks.
Args:
net (nn.Module | list[nn.Module]): Network(s) to be saved.
net_label (str): Network label.
current_iter (int): Current iter number.
param_key (str | list[str]): The parameter key(s) to save network.
Default: 'params'.
"""
if current_iter == -1:
current_iter = 'latest'
save_filename = f'{net_label}_{current_iter}.pth'
save_path = os.path.join(self.opt['path']['models'], save_filename)
net = net if isinstance(net, list) else [net]
param_key = param_key if isinstance(param_key, list) else [param_key]
assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
save_dict = {}
for net_, param_key_ in zip(net, param_key):
net_ = self.get_bare_model(net_)
state_dict = net_.state_dict()
for key, param in state_dict.items():
if key.startswith('module.'): # remove unnecessary 'module.'
key = key[7:]
state_dict[key] = param.cpu()
save_dict[param_key_] = state_dict
torch.save(save_dict, save_path)
def _print_different_keys_loading(self, crt_net, load_net, strict=True):
"""Print keys with differnet name or different size when loading models.
1. Print keys with differnet names.
2. If strict=False, print the same key but with different tensor size.
It also ignore these keys with different sizes (not load).
Args:
crt_net (torch model): Current network.
load_net (dict): Loaded network.
strict (bool): Whether strictly loaded. Default: True.
"""
crt_net = self.get_bare_model(crt_net)
crt_net = crt_net.state_dict()
crt_net_keys = set(crt_net.keys())
load_net_keys = set(load_net.keys())
if crt_net_keys != load_net_keys:
logger.warning('Current net - loaded net:')
for v in sorted(list(crt_net_keys - load_net_keys)):
logger.warning(f' {v}')
logger.warning('Loaded net - current net:')
for v in sorted(list(load_net_keys - crt_net_keys)):
logger.warning(f' {v}')
# check the size for the same keys
if not strict:
common_keys = crt_net_keys & load_net_keys
for k in common_keys:
if crt_net[k].size() != load_net[k].size():
logger.warning(f'Size different, ignore [{k}]: crt_net: '
f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
load_net[k + '.ignore'] = load_net.pop(k)
def load_network(self, net, load_path, strict=True, param_key='params'):
"""Load network.
Args:
load_path (str): The path of networks to be loaded.
net (nn.Module): Network.
strict (bool): Whether strictly loaded.
param_key (str): The parameter key of loaded network. If set to
None, use the root 'path'.
Default: 'params'.
"""
net = self.get_bare_model(net)
logger.info(f'Loading {net.__class__.__name__} model from {load_path}.')
load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
if param_key is not None:
if param_key not in load_net and 'params' in load_net:
param_key = 'params'
logger.info('Loading: params_ema does not exist, use params.')
load_net = load_net[param_key]
# remove unnecessary 'module.'
for k, v in deepcopy(load_net).items():
if k.startswith('module.'):
load_net[k[7:]] = v
load_net.pop(k)
self._print_different_keys_loading(net, load_net, strict)
net.load_state_dict(load_net, strict=strict)
@master_only
def save_training_state(self, epoch, current_iter):
"""Save training states during training, which will be used for
resuming.
Args:
epoch (int): Current epoch.
current_iter (int): Current iteration.
"""
if current_iter != -1:
state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
for o in self.optimizers:
state['optimizers'].append(o.state_dict())
for s in self.schedulers:
state['schedulers'].append(s.state_dict())
save_filename = f'{current_iter}.state'
save_path = os.path.join(self.opt['path']['training_states'], save_filename)
torch.save(state, save_path)
def resume_training(self, resume_state):
"""Reload the optimizers and schedulers for resumed training.
Args:
resume_state (dict): Resume state.
"""
resume_optimizers = resume_state['optimizers']
resume_schedulers = resume_state['schedulers']
assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
for i, o in enumerate(resume_optimizers):
self.optimizers[i].load_state_dict(o)
for i, s in enumerate(resume_schedulers):
self.schedulers[i].load_state_dict(s)
def reduce_loss_dict(self, loss_dict):
"""reduce loss dict.
In distributed training, it averages the losses among different GPUs .
Args:
loss_dict (OrderedDict): Loss dict.
"""
with torch.no_grad():
if self.opt['dist']:
keys = []
losses = []
for name, value in loss_dict.items():
keys.append(name)
losses.append(value)
losses = torch.stack(losses, 0)
torch.distributed.reduce(losses, dst=0)
if self.opt['rank'] == 0:
losses /= self.opt['world_size']
loss_dict = {key: loss for key, loss in zip(keys, losses)}
log_dict = OrderedDict()
for name, value in loss_dict.items():
log_dict[name] = value.mean().item()
return log_dict
================================================
FILE: basicsr/models/codeformer_dirichlet_video_model.py
================================================
import torch
from collections import OrderedDict
from os import path as osp
from tqdm import tqdm
from einops import rearrange
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img, tensor2imgs, images_to_gif
from basicsr.utils.registry import MODEL_REGISTRY
import torch.nn.functional as F
from .sr_model import SRModel
@MODEL_REGISTRY.register()
class CodeFormerDirichletVideoModel(SRModel):
def feed_data(self, data):
self.gt = data['gt'].to(self.device) # b t c h w
self.input = data['in'].to(self.device)
self.lq = data['in'].to(self.device)
self.input_large_de = data['in'].to(self.device)
self.b, self.t = data['gt'].shape[:2]
# self.input_large_de = data['in_large_de'].to(self.device)
# merge b t
self.gt = rearrange(self.gt, "b t ... -> (b t) ...")
self.input = rearrange(self.input, "b t ... -> (b t) ...")
self.input_large_de = rearrange(self.input_large_de, "b t ... -> (b t) ...")
if 'latent_gt' in data:
self.idx_gt = data['latent_gt'].to(self.device)
# self.idx_gt = self.idx_gt.view(self.b, -1)
self.idx_gt = rearrange(self.idx_gt, "b t ... -> (b t) ...")
else:
self.idx_gt = None
def init_training_settings(self):
logger = get_root_logger()
train_opt = self.opt['train']
self.ema_decay = train_opt.get('ema_decay', 0)
if self.ema_decay > 0:
logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
else:
self.model_ema(0) # copy net_g weight
self.net_g_ema.eval()
self.scale_adaptive_gan_weight = train_opt.get('scale_adaptive_gan_weight', 0.8)
# define network net_d
self.net_d = build_network(self.opt['network_d'])
self.net_d = self.model_to_device(self.net_d)
self.print_network(self.net_d)
# load pretrained models
load_path = self.opt['path'].get('pretrain_network_d', None)
if load_path is not None:
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
self.net_g.train()
self.net_d.train()
# define losses
self.cri_pix = None
if train_opt.get('pixel_opt'):
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
self.cri_perceptual = None
if train_opt.get('perceptual_opt'):
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
# add the dir dist KL loss
self.cri_dirichletKL = None
if train_opt.get('dirichletKL_opt'):
self.cri_dirichletKL = build_loss(train_opt['dirichletKL_opt']).to(self.device)
if train_opt.get('gan_opt'):
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
self.fix_generator = train_opt.get('fix_generator', True)
logger.info(f'fix_generator: {self.fix_generator}')
self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
self.net_d_iters = train_opt.get('net_d_iters', 1)
self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
# set up optimizers and schedulers
self.setup_optimizers()
self.setup_schedulers()
def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
return d_weight
def setup_optimizers(self):
train_opt = self.opt['train']
# optimizer g
optim_params_g = []
trainable_modules = []
notrainable_modules = []
for k, v in self.net_g.named_parameters():
module_ = '.'.join(k.split('.')[:2])
if v.requires_grad:
optim_params_g.append(v)
if module_ not in trainable_modules:
trainable_modules.append(module_)
else:
if module_ not in notrainable_modules:
notrainable_modules.append(module_)
logger = get_root_logger()
for _ in trainable_modules:
logger.warning(f'{_} will be optimized.')
for _ in notrainable_modules:
logger.warning(f'{_} will not be optimized.')
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
# optimizer d
optim_type = train_opt['optim_d'].pop('type')
self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
self.optimizers.append(self.optimizer_d)
def gray_resize_for_identity(self, out, size=128):
out_gray = (0.2989 * out[:, 0, :, :] + 0.5870 * out[:, 1, :, :] + 0.1140 * out[:, 2, :, :])
out_gray = out_gray.unsqueeze(1)
out_gray = F.interpolate(out_gray, (size, size), mode='bilinear', align_corners=False)
return out_gray
def optimize_parameters(self, current_iter):
# optimize net_g
for p in self.net_d.parameters():
p.requires_grad = False
self.optimizer_g.zero_grad()
large_de = False
self.output, lq_feat, dirichletDistParam = self.net_g(self.input, w=1.0, detach_16=True)
l_g_total = 0
loss_dict = OrderedDict()
if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
if not large_de: # when large degradation don't need image-level loss
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, self.gt)
l_g_total += l_g_pix
loss_dict['l_g_pix'] = l_g_pix
# perceptual loss
if self.cri_perceptual:
l_g_percep = self.cri_perceptual(self.output, self.gt)
l_g_total += l_g_percep
loss_dict['l_g_percep'] = l_g_percep
if self.cri_dirichletKL:
l_g_dirKL = self.cri_dirichletKL(dirichletDistParam)
l_g_total += l_g_dirKL
loss_dict['l_g_dirichletKL'] = l_g_dirKL
# gan loss
if current_iter > self.net_d_start_iter:
fake_g_pred = self.net_d(self.output)
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
recon_loss = l_g_pix + l_g_percep
loss_dict['recon_loss'] = recon_loss
loss_dict['l_g_gan'] = 0.1 * l_g_gan
l_g_total += recon_loss
l_g_total += l_g_gan
l_g_total.backward()
for name, param in self.net_g.named_parameters():
if not param.requires_grad:
continue
self.optimizer_g.step()
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
# optimize net_d
if not large_de:
if current_iter > self.net_d_start_iter:
for p in self.net_d.parameters():
p.requires_grad = True
self.optimizer_d.zero_grad()
# real
real_d_pred = self.net_d(self.gt)
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
loss_dict['l_d_real'] = l_d_real
loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
l_d_real.backward()
# fake
fake_d_pred = self.net_d(self.output.detach())
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
loss_dict['l_d_fake'] = l_d_fake
loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
l_d_fake.backward()
self.optimizer_d.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
def test(self):
with torch.no_grad():
if hasattr(self, 'net_g_ema'):
self.net_g_ema.eval()
self.output, _, _ = self.net_g_ema(self.input, w=1)
else:
logger = get_root_logger()
logger.warning('Do not have self.net_g_ema, use self.net_g.')
self.net_g.eval()
self.output, _, _ = self.net_g(self.input, w=1)
self.net_g.train()
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
if self.opt['rank'] == 0:
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset_name = dataloader.dataset.opt['name']
with_metrics = self.opt['val'].get('metrics') is not None
if with_metrics:
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
pbar = tqdm(total=len(dataloader), unit='image')
for idx, val_data in enumerate(dataloader):
img_name = val_data["key"][0].split('/')[-3]
self.feed_data(val_data)
self.test()
visuals = self.get_current_visuals()
sr_img = tensor2img([visuals['result']], min_max=(-1, 1))
sr_imgs = tensor2imgs(visuals['result'], min_max=(-1, 1))
if 'gt' in visuals:
gt_img = tensor2img([visuals['gt']], min_max=(-1, 1))
gt_imgs = tensor2imgs(visuals['gt'], min_max=(-1, 1))
del self.gt
# tentative for out of GPU memory
del self.lq
del self.output
torch.cuda.empty_cache()
if save_img:
if self.opt['is_train']:
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
f'{img_name}_{current_iter}.png')
save_img_gif_path = osp.join(self.opt['path']['visualization'], img_name,
f'{img_name}_{current_iter}.gif')
save_img_path_ori = osp.join(self.opt['path']['visualization'], img_name,
f'{img_name}_{current_iter}_gt.png')
save_img_gif_path_ori = osp.join(self.opt['path']['visualization'], img_name,
f'{img_name}_{current_iter}_gt.gif')
else:
if self.opt['val']['suffix']:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["val"]["suffix"]}.png')
save_img_gif_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["val"]["suffix"]}.gif')
save_img_path_ori = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["val"]["suffix"]}_ori.png')
save_img_gif_path_ori = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["val"]["suffix"]}_ori.gif')
else:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["name"]}.png')
save_img_gif_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["name"]}.gif')
save_img_path_ori = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["name"]}_ori.png')
save_img_gif_path_ori = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["name"]}_ori.gif')
imwrite(sr_img, save_img_path)
imwrite(gt_img, save_img_path_ori)
images_to_gif(sr_imgs, save_img_gif_path, duration = 50, loop=4)
images_to_gif(gt_imgs, save_img_gif_path_ori, duration = 50, loop=4)
if with_metrics:
# calculate metrics
for name, opt_ in self.opt['val']['metrics'].items():
metric_data = dict(img1=sr_img, img2=gt_img)
self.metric_results[name] += calculate_metric(metric_data, opt_)
pbar.update(1)
pbar.set_description(f'Test {img_name}')
pbar.close()
if with_metrics:
for metric in self.metric_results.keys():
self.metric_results[metric] /= (idx + 1)
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
log_str = f'Validation {dataset_name}\n'
for metric, value in self.metric_results.items():
log_str += f'\t # {metric}: {value:.4f}\n'
logger = get_root_logger()
logger.info(log_str)
if tb_logger:
for metric, value in self.metric_results.items():
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
def get_current_visuals(self):
out_dict = OrderedDict()
out_dict['gt'] = self.gt.detach().cpu()
out_dict['result'] = self.output.detach().cpu()
return out_dict
def save(self, epoch, current_iter):
if self.ema_decay > 0:
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
else:
self.save_network(self.net_g, 'net_g', current_iter)
self.save_network(self.net_d, 'net_d', current_iter)
self.save_training_state(epoch, current_iter)
================================================
FILE: basicsr/models/lr_scheduler.py
================================================
import math
from collections import Counter
from torch.optim.lr_scheduler import _LRScheduler
class MultiStepRestartLR(_LRScheduler):
""" MultiStep with restarts learning rate scheme.
Args:
optimizer (torch.nn.optimizer): Torch optimizer.
milestones (list): Iterations that will decrease learning rate.
gamma (float): Decrease ratio. Default: 0.1.
restarts (list): Restart iterations. Default: [0].
restart_weights (list): Restart weights at each restart iteration.
Default: [1].
last_epoch (int): Used in _LRScheduler. Default: -1.
"""
def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
self.milestones = Counter(milestones)
self.gamma = gamma
self.restarts = restarts
self.restart_weights = restart_weights
assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
if self.last_epoch in self.restarts:
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
if self.last_epoch not in self.milestones:
return [group['lr'] for group in self.optimizer.param_groups]
return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
def get_position_from_periods(iteration, cumulative_period):
"""Get the position from a period list.
It will return the index of the right-closest number in the period list.
For example, the cumulative_period = [100, 200, 300, 400],
if iteration == 50, return 0;
if iteration == 210, return 2;
if iteration == 300, return 2.
Args:
iteration (int): Current iteration.
cumulative_period (list[int]): Cumulative period list.
Returns:
int: The position of the right-closest number in the period list.
"""
for i, period in enumerate(cumulative_period):
if iteration <= period:
return i
class CosineAnnealingRestartLR(_LRScheduler):
""" Cosine annealing with restarts learning rate scheme.
An example of config:
periods = [10, 10, 10, 10]
restart_weights = [1, 0.5, 0.5, 0.5]
eta_min=1e-7
It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
scheduler will restart with the weights in restart_weights.
Args:
optimizer (torch.nn.optimizer): Torch optimizer.
periods (list): Period for each cosine anneling cycle.
restart_weights (list): Restart weights at each restart iteration.
Default: [1].
eta_min (float): The mimimum lr. Default: 0.
last_epoch (int): Used in _LRScheduler. Default: -1.
"""
def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
self.periods = periods
self.restart_weights = restart_weights
self.eta_min = eta_min
assert (len(self.periods) == len(
self.restart_weights)), 'periods and restart_weights should have the same length.'
self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
current_weight = self.restart_weights[idx]
nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
current_period = self.periods[idx]
return [
self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
(1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
for base_lr in self.base_lrs
]
================================================
FILE: basicsr/models/sr_model.py
================================================
import torch
from collections import OrderedDict
from os import path as osp
from tqdm import tqdm
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
from .base_model import BaseModel
@MODEL_REGISTRY.register()
class SRModel(BaseModel):
"""Base SR model for single image super-resolution."""
def __init__(self, opt):
super(SRModel, self).__init__(opt)
# define network
self.net_g = build_network(opt['network_g'])
self.net_g = self.model_to_device(self.net_g)
self.print_network(self.net_g)
# load pretrained models
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
param_key = self.opt['path'].get('param_key_g', 'params')
self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
if self.is_train:
self.init_training_settings()
def init_training_settings(self):
self.net_g.train()
train_opt = self.opt['train']
self.ema_decay = train_opt.get('ema_decay', 0)
if self.ema_decay > 0:
logger = get_root_logger()
logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
# define network net_g with Exponential Moving Average (EMA)
# net_g_ema is used only for testing on one GPU and saving
# There is no need to wrap with DistributedDataParallel
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
else:
self.model_ema(0) # copy net_g weight
self.net_g_ema.eval()
# define losses
if train_opt.get('pixel_opt'):
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
else:
self.cri_pix = None
if train_opt.get('perceptual_opt'):
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
else:
self.cri_perceptual = None
if self.cri_pix is None and self.cri_perceptual is None:
raise ValueError('Both pixel and perceptual losses are None.')
# set up optimizers and schedulers
self.setup_optimizers()
self.setup_schedulers()
def setup_optimizers(self):
train_opt = self.opt['train']
optim_params = []
for k, v in self.net_g.named_parameters():
if v.requires_grad:
optim_params.append(v)
else:
logger = get_root_logger()
logger.warning(f'Params {k} will not be optimized.')
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
def feed_data(self, data):
self.lq = data['lq'].to(self.device)
if 'gt' in data:
self.gt = data['gt'].to(self.device)
def optimize_parameters(self, current_iter):
self.optimizer_g.zero_grad()
self.output = self.net_g(self.lq)
l_total = 0
loss_dict = OrderedDict()
# pixel loss
if self.cri_pix:
l_pix = self.cri_pix(self.output, self.gt)
l_total += l_pix
loss_dict['l_pix'] = l_pix
# perceptual loss
if self.cri_perceptual:
l_percep, l_style = self.cri_perceptual(self.output, self.gt)
if l_percep is not None:
l_total += l_percep
loss_dict['l_percep'] = l_percep
if l_style is not None:
l_total += l_style
loss_dict['l_style'] = l_style
l_total.backward()
self.optimizer_g.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
def test(self):
if hasattr(self, 'ema_decay'):
self.net_g_ema.eval()
with torch.no_grad():
self.output = self.net_g_ema(self.lq)
else:
self.net_g.eval()
with torch.no_grad():
self.output = self.net_g(self.lq)
self.net_g.train()
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
if self.opt['rank'] == 0:
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset_name = dataloader.dataset.opt['name']
with_metrics = self.opt['val'].get('metrics') is not None
if with_metrics:
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
pbar = tqdm(total=len(dataloader), unit='image')
for idx, val_data in enumerate(dataloader):
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
self.feed_data(val_data)
self.test()
visuals = self.get_current_visuals()
sr_img = tensor2img([visuals['result']])
if 'gt' in visuals:
gt_img = tensor2img([visuals['gt']])
del self.gt
# tentative for out of GPU memory
del self.lq
del self.output
torch.cuda.empty_cache()
if save_img:
if self.opt['is_train']:
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
f'{img_name}_{current_iter}.png')
else:
if self.opt['val']['suffix']:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["val"]["suffix"]}.png')
else:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["name"]}.png')
imwrite(sr_img, save_img_path)
if with_metrics:
# calculate metrics
for name, opt_ in self.opt['val']['metrics'].items():
metric_data = dict(img1=sr_img, img2=gt_img)
self.metric_results[name] += calculate_metric(metric_data, opt_)
pbar.update(1)
pbar.set_description(f'Test {img_name}')
pbar.close()
if with_metrics:
for metric in self.metric_results.keys():
self.metric_results[metric] /= (idx + 1)
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
log_str = f'Validation {dataset_name}\n'
for metric, value in self.metric_results.items():
log_str += f'\t # {metric}: {value:.4f}\n'
logger = get_root_logger()
logger.info(log_str)
if tb_logger:
for metric, value in self.metric_results.items():
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
def get_current_visuals(self):
out_dict = OrderedDict()
out_dict['lq'] = self.lq.detach().cpu()
out_dict['result'] = self.output.detach().cpu()
if hasattr(self, 'gt'):
out_dict['gt'] = self.gt.detach().cpu()
return out_dict
def save(self, epoch, current_iter):
if hasattr(self, 'ema_decay'):
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
else:
self.save_network(self.net_g, 'net_g', current_iter)
self.save_training_state(epoch, current_iter)
================================================
FILE: basicsr/models/vqgan_model.py
================================================
import torch
from collections import OrderedDict
from os import path as osp
from tqdm import tqdm
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
import torch.nn.functional as F
from .sr_model import SRModel
@MODEL_REGISTRY.register()
class VQGANModel(SRModel):
def feed_data(self, data):
self.gt = data['gt'].to(self.device)
self.b = self.gt.shape[0]
def init_training_settings(self):
logger = get_root_logger()
train_opt = self.opt['train']
self.ema_decay = train_opt.get('ema_decay', 0)
if self.ema_decay > 0:
logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
# define network net_g with Exponential Moving Average (EMA)
# net_g_ema is used only for testing on one GPU and saving
# There is no need to wrap with DistributedDataParallel
self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
# load pretrained model
load_path = self.opt['path'].get('pretrain_network_g', None)
if load_path is not None:
self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
else:
self.model_ema(0) # copy net_g weight
self.net_g_ema.eval()
# define network net_d
self.net_d = build_network(self.opt['network_d'])
self.net_d = self.model_to_device(self.net_d)
self.print_network(self.net_d)
# load pretrained models
load_path = self.opt['path'].get('pretrain_network_d', None)
if load_path is not None:
self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True))
self.net_g.train()
self.net_d.train()
# define losses
if train_opt.get('pixel_opt'):
self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
else:
self.cri_pix = None
if train_opt.get('perceptual_opt'):
self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)
else:
self.cri_perceptual = None
if train_opt.get('gan_opt'):
self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)
if train_opt.get('codebook_opt'):
self.l_weight_codebook = train_opt['codebook_opt'].get('loss_weight', 1.0)
else:
self.l_weight_codebook = 1.0
self.vqgan_quantizer = self.opt['network_g']['quantizer']
logger.info(f'vqgan_quantizer: {self.vqgan_quantizer}')
self.net_g_start_iter = train_opt.get('net_g_start_iter', 0)
self.net_d_iters = train_opt.get('net_d_iters', 1)
self.net_d_start_iter = train_opt.get('net_d_start_iter', 0)
self.disc_weight = train_opt.get('disc_weight', 0.8)
# set up optimizers and schedulers
self.setup_optimizers()
self.setup_schedulers()
def calculate_adaptive_weight(self, recon_loss, g_loss, last_layer, disc_weight_max):
recon_grads = torch.autograd.grad(recon_loss, last_layer, retain_graph=True)[0]
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
d_weight = torch.norm(recon_grads) / (torch.norm(g_grads) + 1e-4)
d_weight = torch.clamp(d_weight, 0.0, disc_weight_max).detach()
return d_weight
def adopt_weight(self, weight, global_step, threshold=0, value=0.):
if global_step < threshold:
weight = value
return weight
def setup_optimizers(self):
train_opt = self.opt['train']
# optimizer g
optim_params_g = []
for k, v in self.net_g.named_parameters():
if v.requires_grad:
optim_params_g.append(v)
else:
logger = get_root_logger()
logger.warning(f'Params {k} will not be optimized.')
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
# optimizer d
optim_type = train_opt['optim_d'].pop('type')
self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])
self.optimizers.append(self.optimizer_d)
def optimize_parameters(self, current_iter):
logger = get_root_logger()
loss_dict = OrderedDict()
if self.opt['network_g']['quantizer'] == 'gumbel':
self.net_g.module.quantize.temperature = max(1/16, ((-1/160000) * current_iter) + 1)
if current_iter%1000 == 0:
logger.info(f'temperature: {self.net_g.module.quantize.temperature}')
# optimize net_g
for p in self.net_d.parameters():
p.requires_grad = False
self.optimizer_g.zero_grad()
self.output, l_codebook, quant_stats = self.net_g(self.gt)
l_codebook = l_codebook*self.l_weight_codebook
l_g_total = 0
if current_iter % self.net_d_iters == 0 and current_iter > self.net_g_start_iter:
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, self.gt)
l_g_total += l_g_pix
loss_dict['l_g_pix'] = l_g_pix
# perceptual loss
if self.cri_perceptual:
l_g_percep = self.cri_perceptual(self.output, self.gt)
l_g_total += l_g_percep
loss_dict['l_g_percep'] = l_g_percep
# gan loss
if current_iter > self.net_d_start_iter:
# fake_g_pred = self.net_d(self.output_1024)
fake_g_pred = self.net_d(self.output)
l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
recon_loss = l_g_total
last_layer = self.net_g.module.generator.blocks[-1].weight
d_weight = self.calculate_adaptive_weight(recon_loss, l_g_gan, last_layer, disc_weight_max=1.0)
d_weight *= self.adopt_weight(1, current_iter, self.net_d_start_iter)
d_weight *= self.disc_weight # tamming setting 0.8
l_g_total += d_weight * l_g_gan
loss_dict['l_g_gan'] = d_weight * l_g_gan
l_g_total += l_codebook
loss_dict['l_codebook'] = l_codebook
l_g_total.backward()
self.optimizer_g.step()
# optimize net_d
if current_iter > self.net_d_start_iter:
for p in self.net_d.parameters():
p.requires_grad = True
self.optimizer_d.zero_grad()
# real
real_d_pred = self.net_d(self.gt)
l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
loss_dict['l_d_real'] = l_d_real
loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
l_d_real.backward()
# fake
fake_d_pred = self.net_d(self.output.detach())
l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
loss_dict['l_d_fake'] = l_d_fake
loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
l_d_fake.backward()
self.optimizer_d.step()
self.log_dict = self.reduce_loss_dict(loss_dict)
if self.ema_decay > 0:
self.model_ema(decay=self.ema_decay)
def test(self):
with torch.no_grad():
if hasattr(self, 'net_g_ema'):
self.net_g_ema.eval()
self.output, _, _ = self.net_g_ema(self.gt)
else:
logger = get_root_logger()
logger.warning('Do not have self.net_g_ema, use self.net_g.')
self.net_g.eval()
self.output, _, _ = self.net_g(self.gt)
self.net_g.train()
def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
if self.opt['rank'] == 0:
self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
dataset_name = dataloader.dataset.opt['name']
with_metrics = self.opt['val'].get('metrics') is not None
if with_metrics:
self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
pbar = tqdm(total=len(dataloader), unit='image')
for idx, val_data in enumerate(dataloader):
img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
self.feed_data(val_data)
self.test()
visuals = self.get_current_visuals()
sr_img = tensor2img([visuals['result']])
if 'gt' in visuals:
gt_img = tensor2img([visuals['gt']])
del self.gt
# tentative for out of GPU memory
del self.lq
del self.output
torch.cuda.empty_cache()
if save_img:
if self.opt['is_train']:
save_img_path = osp.join(self.opt['path']['visualization'], img_name,
f'{img_name}_{current_iter}.png')
else:
if self.opt['val']['suffix']:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["val"]["suffix"]}.png')
else:
save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
f'{img_name}_{self.opt["name"]}.png')
imwrite(sr_img, save_img_path)
if with_metrics:
# calculate metrics
for name, opt_ in self.opt['val']['metrics'].items():
metric_data = dict(img1=sr_img, img2=gt_img)
self.metric_results[name] += calculate_metric(metric_data, opt_)
pbar.update(1)
pbar.set_description(f'Test {img_name}')
pbar.close()
if with_metrics:
for metric in self.metric_results.keys():
self.metric_results[metric] /= (idx + 1)
self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
log_str = f'Validation {dataset_name}\n'
for metric, value in self.metric_results.items():
log_str += f'\t # {metric}: {value:.4f}\n'
logger = get_root_logger()
logger.info(log_str)
if tb_logger:
for metric, value in self.metric_results.items():
tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
def get_current_visuals(self):
out_dict = OrderedDict()
out_dict['gt'] = self.gt.detach().cpu()
out_dict['result'] = self.output.detach().cpu()
return out_dict
def save(self, epoch, current_iter):
if self.ema_decay > 0:
self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
else:
self.save_network(self.net_g, 'net_g', current_iter)
self.save_network(self.net_d, 'net_d', current_iter)
self.save_training_state(epoch, current_iter)
================================================
FILE: basicsr/ops/__init__.py
================================================
================================================
FILE: basicsr/ops/dcn/__init__.py
================================================
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
modulated_deform_conv)
__all__ = [
'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
'modulated_deform_conv'
]
================================================
FILE: basicsr/ops/dcn/deform_conv.py
================================================
import math
import torch
from torch import nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn import functional as F
from torch.nn.modules.utils import _pair, _single
try:
from . import deform_conv_ext
except ImportError:
import os
BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
deform_conv_ext = load(
'deform_conv',
sources=[
os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
],
)
class DeformConvFunction(Function):
@staticmethod
def forward(ctx,
input,
offset,
weight,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
im2col_step=64):
if input is not None and input.dim() != 4:
raise ValueError(f'Expected 4D tensor as input, got {input.dim()}' 'D tensor instead.')
ctx.stride = _pair(stride)
ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation)
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.im2col_step = im2col_step
ctx.save_for_backward(input, offset, weight)
output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
if not input.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
deform_conv_ext.deform_conv_forward(input, weight,
offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
ctx.deformable_groups, cur_im2col_step)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
input, offset, weight = ctx.saved_tensors
grad_input = grad_offset = grad_weight = None
if not grad_output.is_cuda:
raise NotImplementedError
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
grad_offset, weight, ctx.bufs_[0], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
ctx.deformable_groups, cur_im2col_step)
if ctx.needs_input_grad[2]:
grad_weight = torch.zeros_like(weight)
deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
weight.size(2), ctx.stride[1], ctx.stride[0],
ctx.padding[1], ctx.padding[0], ctx.dilation[1],
ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
cur_im2col_step)
return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
@staticmethod
def _output_size(input, weight, padding, dilation, stride):
channels = weight.size(0)
output_size = (input.size(0), channels)
for d in range(input.dim() - 2):
in_size = input.size(d + 2)
pad = padding[d]
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = stride[d]
output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
if not all(map(lambda s: s > 0, output_size)):
raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, output_size))})')
return output_size
class ModulatedDeformConvFunction(Function):
@staticmethod
def forward(ctx,
input,
offset,
mask,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
ctx.groups = groups
ctx.deformable_groups = deformable_groups
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(1) # fake tensor
if not input.is_cuda:
raise NotImplementedError
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.groups, ctx.deformable_groups, ctx.with_bias)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
if not grad_output.is_cuda:
raise NotImplementedError
input, offset, mask, weight, bias = ctx.saved_tensors
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
grad_output, weight.shape[2], weight.shape[3], ctx.stride,
ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
ctx.groups, ctx.deformable_groups, ctx.with_bias)
if not ctx.with_bias:
grad_bias = None
return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
@staticmethod
def _infer_shape(ctx, input, weight):
n = input.size(0)
channels_out = weight.size(0)
height, width = input.shape[2:4]
kernel_h, kernel_w = weight.shape[2:4]
height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
return n, channels_out, height_out, width_out
deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply
class DeformConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=False):
super(DeformConv, self).__init__()
assert not bias
assert in_channels % groups == 0, \
f'in_channels {in_channels} is not divisible by groups {groups}'
assert out_channels % groups == 0, \
f'out_channels {out_channels} is not divisible ' \
f'by groups {groups}'
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = _pair(stride)
self.padding = _pair(padding)
self.dilation = _pair(dilation)
self.groups = groups
self.deformable_groups = deformable_groups
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
self.reset_parameters()
def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, x, offset):
# To fix an assert error in deform_conv_cuda.cpp:128
# input image is smaller than kernel
input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
if input_pad:
pad_h = max(self.kernel_size[0] - x.size(2), 0)
pad_w = max(self.kernel_size[1] - x.size(3), 0)
x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
self.deformable_groups)
if input_pad:
out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
return out
class DeformConvPack(DeformConv):
"""A Deformable Conv Encapsulation that acts as normal Conv layers.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
"""
_version = 2
def __init__(self, *args, **kwargs):
super(DeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
dilation=_pair(self.dilation),
bias=True)
self.init_offset()
def init_offset(self):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
offset = self.conv_offset(x)
return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
self.deformable_groups)
class ModulatedDeformConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
deformable_groups=1,
bias=True):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = _pair(kernel_size)
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.deformable_groups = deformable_groups
self.with_bias = bias
# enable compatibility with nn.Conv2d
self.transposed = False
self.output_padding = _single(0)
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
if bias:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.init_weights()
def init_weights(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
stdv = 1. / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, x, offset, mask):
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
class ModulatedDeformConvPack(ModulatedDeformConv):
"""A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
kernel_size (int or tuple[int]): Same as nn.Conv2d.
stride (int or tuple[int]): Same as nn.Conv2d.
padding (int or tuple[int]): Same as nn.Conv2d.
dilation (int or tuple[int]): Same as nn.Conv2d.
groups (int): Same as nn.Conv2d.
bias (bool or str): If specified as `auto`, it will be decided by the
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
False.
"""
_version = 2
def __init__(self, *args, **kwargs):
super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d(
self.in_channels,
self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
kernel_size=self.kernel_size,
stride=_pair(self.stride),
padding=_pair(self.padding),
dilation=_pair(self.dilation),
bias=True)
self.init_weights()
def init_weights(self):
super(ModulatedDeformConvPack, self).init_weights()
if hasattr(self, 'conv_offset'):
self.conv_offset.weight.data.zero_()
self.conv_offset.bias.data.zero_()
def forward(self, x):
out = self.conv_offset(x)
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
self.groups, self.deformable_groups)
================================================
FILE: basicsr/ops/dcn/src/deform_conv_cuda.cpp
================================================
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
#include
#include
#include
#include
void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
const int channels, const int height, const int width,
const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor data_col);
void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
const int channels, const int height, const int width,
const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor grad_im);
void deformable_col2im_coord(
const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const int channels, const int height,
const int width, const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int parallel_imgs,
const int deformable_group, at::Tensor grad_offset);
void modulated_deformable_im2col_cuda(
const at::Tensor data_im, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
at::Tensor data_col);
void modulated_deformable_col2im_cuda(
const at::Tensor data_col, const at::Tensor data_offset,
const at::Tensor data_mask, const int batch_size, const int channels,
const int height_im, const int width_im, const int height_col,
const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int deformable_group,
at::Tensor grad_im);
void modulated_deformable_col2im_coord_cuda(
const at::Tensor data_col, const at::Tensor data_im,
const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im,
const int width_im, const int height_col, const int width_col,
const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w, const int dilation_h,
const int dilation_w, const int deformable_group, at::Tensor grad_offset,
at::Tensor grad_mask);
void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
int padW, int dilationH, int dilationW, int group,
int deformable_group) {
TORCH_CHECK(weight.ndimension() == 4,
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
"but got: %s",
weight.ndimension());
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
TORCH_CHECK(kW > 0 && kH > 0,
"kernel size should be greater than zero, but got kH: %d kW: %d", kH,
kW);
TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
"kernel size should be consistent with weight, ",
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
kW, weight.size(2), weight.size(3));
TORCH_CHECK(dW > 0 && dH > 0,
"stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
TORCH_CHECK(
dilationW > 0 && dilationH > 0,
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
dilationH, dilationW);
int ndim = input.ndimension();
int dimf = 0;
int dimh = 1;
int dimw = 2;
if (ndim == 4) {
dimf++;
dimh++;
dimw++;
}
TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
ndim);
long nInputPlane = weight.size(1) * group;
long inputHeight = input.size(dimh);
long inputWidth = input.size(dimw);
long nOutputPlane = weight.size(0);
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
TORCH_CHECK(nInputPlane % deformable_group == 0,
"input channels must divide deformable group size");
if (outputWidth < 1 || outputHeight < 1)
AT_ERROR(
"Given input size: (%ld x %ld x %ld). "
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
outputWidth);
TORCH_CHECK(input.size(1) == nInputPlane,
"invalid number of input planes, expected: %d, but got: %d",
nInputPlane, input.size(1));
TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
"input image is smaller than kernel");
TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
"invalid spatial size of offset, expected height: %d width: %d, but "
"got height: %d width: %d",
outputHeight, outputWidth, offset.size(2), offset.size(3));
TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
"invalid number of channels of offset");
if (gradOutput != NULL) {
TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
"invalid number of gradOutput planes, expected: %d, but got: %d",
nOutputPlane, gradOutput->size(dimf));
TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
gradOutput->size(dimw) == outputWidth),
"invalid size of gradOutput, expected height: %d width: %d , but "
"got height: %d width: %d",
outputHeight, outputWidth, gradOutput->size(dimh),
gradOutput->size(dimw));
}
}
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
at::Tensor offset, at::Tensor output,
at::Tensor columns, at::Tensor ones, int kW,
int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH, int group,
int deformable_group, int im2col_step) {
// todo: resize columns to include im2col: done
// todo: add im2col_step as input
// todo: add new output buffer and transpose it to output (or directly
// transpose output) todo: possibly change data indexing because of
// parallel_imgs
shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
dilationH, dilationW, group, deformable_group);
at::DeviceGuard guard(input.device());
input = input.contiguous();
offset = offset.contiguous();
weight = weight.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input.unsqueeze_(0);
offset.unsqueeze_(0);
}
// todo: assert batchsize dividable by im2col_step
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
outputHeight, outputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
ones = at::ones({outputHeight, outputWidth}, input.options());
}
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
at::Tensor output_buffer =
at::zeros({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth},
output.options());
output_buffer = output_buffer.view(
{output_buffer.size(0), group, output_buffer.size(1) / group,
output_buffer.size(2), output_buffer.size(3)});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
output_buffer[elt][g] = output_buffer[elt][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output_buffer[elt][g]);
}
}
output_buffer = output_buffer.view(
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
output_buffer.size(3), output_buffer.size(4)});
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step, outputHeight, outputWidth});
output_buffer.transpose_(1, 2);
output.copy_(output_buffer);
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
output = output.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
}
return 1;
}
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
at::Tensor gradOutput, at::Tensor gradInput,
at::Tensor gradOffset, at::Tensor weight,
at::Tensor columns, int kW, int kH, int dW,
int dH, int padW, int padH, int dilationW,
int dilationH, int group,
int deformable_group, int im2col_step) {
shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
dilationH, dilationW, group, deformable_group);
at::DeviceGuard guard(input.device());
input = input.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous();
weight = weight.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input = input.view({1, input.size(0), input.size(1), input.size(2)});
offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
// change order of grad output
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight,
outputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
// divide into groups
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
gradOutput = gradOutput.view(
{gradOutput.size(0), group, gradOutput.size(1) / group,
gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
for (int g = 0; g < group; g++) {
columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradOutput = gradOutput.view(
{gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
dilationH, dilationW, im2col_step, deformable_group,
gradOffset[elt]);
deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, gradInput[elt]);
}
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
gradOffset = gradOffset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
gradOffset =
gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
}
return 1;
}
int deform_conv_backward_parameters_cuda(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
int padW, int padH, int dilationW, int dilationH, int group,
int deformable_group, float scale, int im2col_step) {
// todo: transpose and reshape outGrad
// todo: reshape columns
// todo: add im2col_step as input
shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
padW, dilationH, dilationW, group, deformable_group);
at::DeviceGuard guard(input.device());
input = input.contiguous();
offset = offset.contiguous();
gradOutput = gradOutput.contiguous();
int batch = 1;
if (input.ndimension() == 3) {
// Force batch
batch = 0;
input = input.view(
at::IntList({1, input.size(0), input.size(1), input.size(2)}));
gradOutput = gradOutput.view(
{1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = gradWeight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
nOutputPlane, outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
outputHeight, outputWidth});
gradOutputBuffer.copy_(gradOutput);
gradOutputBuffer =
gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth});
gradOutput.transpose_(1, 2);
gradOutput =
gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
offset =
offset.view({batchSize / im2col_step, im2col_step,
deformable_group * 2 * kH * kW, outputHeight, outputWidth});
for (int elt = 0; elt < batchSize / im2col_step; elt++) {
deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, deformable_group, columns);
// divide into group
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
gradWeight =
gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
gradWeight.size(2), gradWeight.size(3)});
for (int g = 0; g < group; g++) {
gradWeight[g] = gradWeight[g]
.flatten(1)
.addmm_(gradOutputBuffer[elt][g].flatten(1),
columns[g].transpose(1, 0), 1.0, scale)
.view_as(gradWeight[g]);
}
gradOutputBuffer = gradOutputBuffer.view(
{gradOutputBuffer.size(0),
gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
gradWeight.size(2), gradWeight.size(3),
gradWeight.size(4)});
}
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
offset = offset.view(
{batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
if (batch == 0) {
gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
}
return 1;
}
void modulated_deform_conv_cuda_forward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int deformable_group,
const bool with_bias) {
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_out = weight.size(0);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
// resize output
output = output.view({batch, channels_out, height_out, width_out}).zero_();
// resize temporary columns
columns =
at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
input.options());
output = output.view({output.size(0), group, output.size(1) / group,
output.size(2), output.size(3)});
for (int b = 0; b < batch; b++) {
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
// divide into group
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
for (int g = 0; g < group; g++) {
output[b][g] = output[b][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output[b][g]);
}
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
output = output.view({output.size(0), output.size(1) * output.size(2),
output.size(3), output.size(4)});
if (with_bias) {
output += bias.view({1, bias.size(0), 1, 1});
}
}
void modulated_deform_conv_cuda_backward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor columns,
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias) {
TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
at::DeviceGuard guard(input.device());
const int batch = input.size(0);
const int channels = input.size(1);
const int height = input.size(2);
const int width = input.size(3);
const int channels_kernel = weight.size(1);
const int kernel_h_ = weight.size(2);
const int kernel_w_ = weight.size(3);
if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
kernel_h_, kernel_w, kernel_h_, kernel_w_);
if (channels != channels_kernel * group)
AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
channels, channels_kernel * group);
const int height_out =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
if (ones.ndimension() != 2 ||
ones.size(0) * ones.size(1) < height_out * width_out) {
// Resize plane and fill with ones...
ones = at::ones({height_out, width_out}, input.options());
}
grad_input = grad_input.view({batch, channels, height, width});
columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
input.options());
grad_output =
grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
grad_output.size(2), grad_output.size(3)});
for (int b = 0; b < batch; b++) {
// divide int group
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++) {
columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
grad_output[b][g].flatten(1), 0.0f, 1.0f);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(
columns, input[b], offset[b], mask[b], 1, channels, height, width,
height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
grad_mask[b]);
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(
columns, offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, grad_input[b]);
// gradient w.r.t. weight, dWeight should accumulate across the batch and
// group
modulated_deformable_im2col_cuda(
input[b], offset[b], mask[b], 1, channels, height, width, height_out,
width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group, columns);
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
grad_weight.size(1), grad_weight.size(2),
grad_weight.size(3)});
if (with_bias)
grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
for (int g = 0; g < group; g++) {
grad_weight[g] =
grad_weight[g]
.flatten(1)
.addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
.view_as(grad_weight[g]);
if (with_bias) {
grad_bias[g] =
grad_bias[g]
.view({-1, 1})
.addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
.view(-1);
}
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2), grad_weight.size(3),
grad_weight.size(4)});
if (with_bias)
grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
}
grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
grad_output.size(2), grad_output.size(3),
grad_output.size(4)});
}
================================================
FILE: basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
================================================
/*!
******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
*
* COPYRIGHT
*
* All contributions by the University of California:
* Copyright (c) 2014-2017 The Regents of the University of California (Regents)
* All rights reserved.
*
* All other contributions:
* Copyright (c) 2014-2017, the respective contributors
* All rights reserved.
*
* Caffe uses a shared copyright model: each contributor holds copyright over
* their contributions to Caffe. The project versioning records all such
* contribution and copyright details. If a contributor wants to further mark
* their specific copyright on a particular contribution, they should indicate
* their copyright solely in the commit message of the change when it is
* committed.
*
* LICENSE
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* CONTRIBUTION AGREEMENT
*
* By contributing to the BVLC/caffe repository through pull-request, comment,
* or otherwise, the contributor releases their content to the
* license and copyright terms herein.
*
***************** END Caffe Copyright Notice and Disclaimer ********************
*
* Copyright (c) 2018 Microsoft
* Licensed under The MIT License [see LICENSE for details]
* \file modulated_deformable_im2col.cuh
* \brief Function definitions of converting an image to
* column matrix based on kernel, padding, dilation, and offset.
* These functions are mainly used in deformable convolution operators.
* \ref: https://arxiv.org/abs/1703.06211
* \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
*/
// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
#include
#include
#include
#include
#include
#include
using namespace at;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
const int CUDA_NUM_THREADS = 1024;
const int kMaxGridNum = 65535;
inline int GET_BLOCKS(const int N)
{
return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
}
template
__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
const int height, const int width, scalar_t h, scalar_t w)
{
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = bottom_data[h_low * data_width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[h_low * data_width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[h_high * data_width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[h_high * data_width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template
__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
const int h, const int w, const int height, const int width)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
template
__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
const int height, const int width, const scalar_t *im_data,
const int data_width, const int bp_dir)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (bp_dir == 0)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
else if (bp_dir == 1)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
template
__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
const int batch_size, const int num_channels, const int deformable_group,
const int height_col, const int width_col,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
//const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i)
{
for (int j = 0; j < kernel_w; ++j)
{
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
scalar_t val = static_cast(0);
const scalar_t h_im = h_in + i * dilation_h + offset_h;
const scalar_t w_im = w_in + j * dilation_w + offset_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
{
//const scalar_t map_h = i * dilation_h + offset_h;
//const scalar_t map_w = j * dilation_w + offset_w;
//const int cur_height = height - h_in;
//const int cur_width = width - w_in;
//val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
void deformable_im2col(
const at::Tensor data_im, const at::Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h, const int ksize_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w, const int parallel_imgs,
const int deformable_group, at::Tensor data_col)
{
// num_axes should be smaller than block size
// todo: check parallel_imgs is correctly passed in
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
const scalar_t *data_im_ = data_im.data_ptr();
const scalar_t *data_offset_ = data_offset.data_ptr();
scalar_t *data_col_ = data_col.data_ptr();
deformable_im2col_gpu_kernel<<>>(
num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
channel_per_deformable_group, parallel_imgs, channels, deformable_group,
height_col, width_col, data_col_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
}
}
template
__global__ void deformable_col2im_gpu_kernel(
const int n, const scalar_t *data_col, const scalar_t *data_offset,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int deformable_group,
const int height_col, const int width_col,
scalar_t *grad_im)
{
CUDA_KERNEL_LOOP(index, n)
{
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
2 * kernel_h * kernel_w * height_col * width_col;
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
const scalar_t cur_top_grad = data_col[index];
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++)
{
for (int dx = -2; dx <= 2; dx++)
{
if (cur_h + dy >= 0 && cur_h + dy < height &&
cur_w + dx >= 0 && cur_w + dx < width &&
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1)
{
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
}
}
}
}
}
void deformable_col2im(
const at::Tensor data_col, const at::Tensor data_offset, const int channels,
const int height, const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group,
at::Tensor grad_im)
{
// todo: make sure parallel_imgs is passed in correctly
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
int channel_per_deformable_group = channels / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr();
const scalar_t *data_offset_ = data_offset.data_ptr();
scalar_t *grad_im_ = grad_im.data_ptr();
deformable_col2im_gpu_kernel<<>>(
num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
parallel_imgs, deformable_group, height_col, width_col, grad_im_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
}
}
template
__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
const scalar_t *data_im, const scalar_t *data_offset,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int offset_channels, const int deformable_group,
const int height_col, const int width_col, scalar_t *grad_offset)
{
CUDA_KERNEL_LOOP(index, n)
{
scalar_t val = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
batch_size * width_col * height_col;
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
channel_per_deformable_group / kernel_h / kernel_w * height * width;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
{
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
scalar_t inv_h = h_in + i * dilation_h + offset_h;
scalar_t inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
{
inv_h = inv_w = -2;
}
const scalar_t weight = get_coordinate_weight(
inv_h, inv_w,
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
val += weight * data_col_ptr[col_pos];
cnt += 1;
}
grad_offset[index] = val;
}
}
void deformable_col2im_coord(
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
const int channels, const int height, const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
const int stride_w, const int dilation_h, const int dilation_w,
const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
{
int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr();
const scalar_t *data_im_ = data_im.data_ptr();
const scalar_t *data_offset_ = data_offset.data_ptr();
scalar_t *grad_offset_ = grad_offset.data_ptr();
deformable_col2im_coord_gpu_kernel<<>>(
num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
height_col, width_col, grad_offset_);
}));
}
template
__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
const int height, const int width, scalar_t h, scalar_t w)
{
int h_low = floor(h);
int w_low = floor(w);
int h_high = h_low + 1;
int w_high = w_low + 1;
scalar_t lh = h - h_low;
scalar_t lw = w - w_low;
scalar_t hh = 1 - lh, hw = 1 - lw;
scalar_t v1 = 0;
if (h_low >= 0 && w_low >= 0)
v1 = bottom_data[h_low * data_width + w_low];
scalar_t v2 = 0;
if (h_low >= 0 && w_high <= width - 1)
v2 = bottom_data[h_low * data_width + w_high];
scalar_t v3 = 0;
if (h_high <= height - 1 && w_low >= 0)
v3 = bottom_data[h_high * data_width + w_low];
scalar_t v4 = 0;
if (h_high <= height - 1 && w_high <= width - 1)
v4 = bottom_data[h_high * data_width + w_high];
scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template
__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
const int h, const int w, const int height, const int width)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (h == argmax_h_low && w == argmax_w_low)
weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
if (h == argmax_h_low && w == argmax_w_high)
weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
if (h == argmax_h_high && w == argmax_w_low)
weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
if (h == argmax_h_high && w == argmax_w_high)
weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
return weight;
}
template
__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
const int height, const int width, const scalar_t *im_data,
const int data_width, const int bp_dir)
{
if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
{
//empty
return 0;
}
int argmax_h_low = floor(argmax_h);
int argmax_w_low = floor(argmax_w);
int argmax_h_high = argmax_h_low + 1;
int argmax_w_high = argmax_w_low + 1;
scalar_t weight = 0;
if (bp_dir == 0)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
else if (bp_dir == 1)
{
if (argmax_h_low >= 0 && argmax_w_low >= 0)
weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
}
return weight;
}
template
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int num_channels, const int deformable_group,
const int height_col, const int width_col,
scalar_t *data_col)
{
CUDA_KERNEL_LOOP(index, n)
{
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
// compute deformable group index
const int deformable_group_index = c_im / channel_per_deformable_group;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
//const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
for (int i = 0; i < kernel_h; ++i)
{
for (int j = 0; j < kernel_w; ++j)
{
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
scalar_t val = static_cast(0);
const scalar_t h_im = h_in + i * dilation_h + offset_h;
const scalar_t w_im = w_in + j * dilation_w + offset_w;
//if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
{
//const float map_h = i * dilation_h + offset_h;
//const float map_w = j * dilation_w + offset_w;
//const int cur_height = height - h_in;
//const int cur_width = width - w_in;
//val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
}
*data_col_ptr = val * mask;
data_col_ptr += batch_size * height_col * width_col;
//data_col_ptr += height_col * width_col;
}
}
}
}
template
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int deformable_group,
const int height_col, const int width_col,
scalar_t *grad_im)
{
CUDA_KERNEL_LOOP(index, n)
{
const int j = (index / width_col / height_col / batch_size) % kernel_w;
const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
// compute the start and end of the output
const int deformable_group_index = c / channel_per_deformable_group;
int w_out = index % width_col;
int h_out = (index / width_col) % height_col;
int b = (index / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
const scalar_t cur_top_grad = data_col[index] * mask;
const int cur_h = (int)cur_inv_h_data;
const int cur_w = (int)cur_inv_w_data;
for (int dy = -2; dy <= 2; dy++)
{
for (int dx = -2; dx <= 2; dx++)
{
if (cur_h + dy >= 0 && cur_h + dy < height &&
cur_w + dx >= 0 && cur_w + dx < width &&
abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1)
{
int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
}
}
}
}
}
template
__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
const scalar_t *data_col, const scalar_t *data_im,
const scalar_t *data_offset, const scalar_t *data_mask,
const int channels, const int height, const int width,
const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int channel_per_deformable_group,
const int batch_size, const int offset_channels, const int deformable_group,
const int height_col, const int width_col,
scalar_t *grad_offset, scalar_t *grad_mask)
{
CUDA_KERNEL_LOOP(index, n)
{
scalar_t val = 0, mval = 0;
int w = index % width_col;
int h = (index / width_col) % height_col;
int c = (index / width_col / height_col) % offset_channels;
int b = (index / width_col / height_col) / offset_channels;
// compute the start and end of the output
const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
{
const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
scalar_t inv_h = h_in + i * dilation_h + offset_h;
scalar_t inv_w = w_in + j * dilation_w + offset_w;
if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
{
inv_h = inv_w = -2;
}
else
{
mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
}
const scalar_t weight = dmcn_get_coordinate_weight(
inv_h, inv_w,
height, width, data_im_ptr + cnt * height * width, width, bp_dir);
val += weight * data_col_ptr[col_pos] * mask;
cnt += 1;
}
// KERNEL_ASSIGN(grad_offset[index], offset_req, val);
grad_offset[index] = val;
if (offset_c % 2 == 0)
// KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
}
}
void modulated_deformable_im2col_cuda(
const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor data_col)
{
// num_axes should be smaller than block size
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * batch_size * height_col * width_col;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
const scalar_t *data_im_ = data_im.data_ptr();
const scalar_t *data_offset_ = data_offset.data_ptr();
const scalar_t *data_mask_ = data_mask.data_ptr();
scalar_t *data_col_ = data_col.data_ptr();
modulated_deformable_im2col_gpu_kernel<<>>(
num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
batch_size, channels, deformable_group, height_col, width_col, data_col_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
}
}
void modulated_deformable_col2im_cuda(
const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group, at::Tensor grad_im)
{
const int channel_per_deformable_group = channels / deformable_group;
const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr();
const scalar_t *data_offset_ = data_offset.data_ptr();
const scalar_t *data_mask_ = data_mask.data_ptr();
scalar_t *grad_im_ = grad_im.data_ptr();
modulated_deformable_col2im_gpu_kernel<<>>(
num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
batch_size, deformable_group, height_col, width_col, grad_im_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
}
}
void modulated_deformable_col2im_coord_cuda(
const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
const int batch_size, const int channels, const int height_im, const int width_im,
const int height_col, const int width_col, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int deformable_group,
at::Tensor grad_offset, at::Tensor grad_mask)
{
const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
const scalar_t *data_col_ = data_col.data_ptr();
const scalar_t *data_im_ = data_im.data_ptr();
const scalar_t *data_offset_ = data_offset.data_ptr();
const scalar_t *data_mask_ = data_mask.data_ptr();
scalar_t *grad_offset_ = grad_offset.data_ptr();
scalar_t *grad_mask_ = grad_mask.data_ptr();
modulated_deformable_col2im_coord_gpu_kernel<<>>(
num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, channel_per_deformable_group,
batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
grad_offset_, grad_mask_);
}));
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess)
{
printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
}
}
================================================
FILE: basicsr/ops/dcn/src/deform_conv_ext.cpp
================================================
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
#include
#include
#include
#include
#define WITH_CUDA // always use cuda
#ifdef WITH_CUDA
int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
at::Tensor offset, at::Tensor output,
at::Tensor columns, at::Tensor ones, int kW,
int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH, int group,
int deformable_group, int im2col_step);
int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
at::Tensor gradOutput, at::Tensor gradInput,
at::Tensor gradOffset, at::Tensor weight,
at::Tensor columns, int kW, int kH, int dW,
int dH, int padW, int padH, int dilationW,
int dilationH, int group,
int deformable_group, int im2col_step);
int deform_conv_backward_parameters_cuda(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
int padW, int padH, int dilationW, int dilationH, int group,
int deformable_group, float scale, int im2col_step);
void modulated_deform_conv_cuda_forward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int deformable_group,
const bool with_bias);
void modulated_deform_conv_cuda_backward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor columns,
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias);
#endif
int deform_conv_forward(at::Tensor input, at::Tensor weight,
at::Tensor offset, at::Tensor output,
at::Tensor columns, at::Tensor ones, int kW,
int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH, int group,
int deformable_group, int im2col_step) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
return deform_conv_forward_cuda(input, weight, offset, output, columns,
ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
deformable_group, im2col_step);
#else
AT_ERROR("deform conv is not compiled with GPU support");
#endif
}
AT_ERROR("deform conv is not implemented on CPU");
}
int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
at::Tensor gradOutput, at::Tensor gradInput,
at::Tensor gradOffset, at::Tensor weight,
at::Tensor columns, int kW, int kH, int dW,
int dH, int padW, int padH, int dilationW,
int dilationH, int group,
int deformable_group, int im2col_step) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
return deform_conv_backward_input_cuda(input, offset, gradOutput,
gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
dilationW, dilationH, group, deformable_group, im2col_step);
#else
AT_ERROR("deform conv is not compiled with GPU support");
#endif
}
AT_ERROR("deform conv is not implemented on CPU");
}
int deform_conv_backward_parameters(
at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
at::Tensor gradWeight, // at::Tensor gradBias,
at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
int padW, int padH, int dilationW, int dilationH, int group,
int deformable_group, float scale, int im2col_step) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
dilationH, group, deformable_group, scale, im2col_step);
#else
AT_ERROR("deform conv is not compiled with GPU support");
#endif
}
AT_ERROR("deform conv is not implemented on CPU");
}
void modulated_deform_conv_forward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
int kernel_h, int kernel_w, const int stride_h, const int stride_w,
const int pad_h, const int pad_w, const int dilation_h,
const int dilation_w, const int group, const int deformable_group,
const bool with_bias) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
offset, mask, output, columns, kernel_h, kernel_w, stride_h,
stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
deformable_group, with_bias);
#else
AT_ERROR("modulated deform conv is not compiled with GPU support");
#endif
}
AT_ERROR("modulated deform conv is not implemented on CPU");
}
void modulated_deform_conv_backward(
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
at::Tensor offset, at::Tensor mask, at::Tensor columns,
at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
const bool with_bias) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
with_bias);
#else
AT_ERROR("modulated deform conv is not compiled with GPU support");
#endif
}
AT_ERROR("modulated deform conv is not implemented on CPU");
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("deform_conv_forward", &deform_conv_forward,
"deform forward");
m.def("deform_conv_backward_input", &deform_conv_backward_input,
"deform_conv_backward_input");
m.def("deform_conv_backward_parameters",
&deform_conv_backward_parameters,
"deform_conv_backward_parameters");
m.def("modulated_deform_conv_forward",
&modulated_deform_conv_forward,
"modulated deform conv forward");
m.def("modulated_deform_conv_backward",
&modulated_deform_conv_backward,
"modulated deform conv backward");
}
================================================
FILE: basicsr/ops/fused_act/__init__.py
================================================
from .fused_act import FusedLeakyReLU, fused_leaky_relu
__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
================================================
FILE: basicsr/ops/fused_act/fused_act.py
================================================
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
import torch
from torch import nn
from torch.autograd import Function
try:
from . import fused_act_ext
except ImportError:
import os
BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
fused_act_ext = load(
'fused',
sources=[
os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
],
)
class FusedLeakyReLUFunctionBackward(Function):
@staticmethod
def forward(ctx, grad_output, out, negative_slope, scale):
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
empty = grad_output.new_empty(0)
grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
dim = [0]
if grad_input.ndim > 2:
dim += list(range(2, grad_input.ndim))
grad_bias = grad_input.sum(dim).detach()
return grad_input, grad_bias
@staticmethod
def backward(ctx, gradgrad_input, gradgrad_bias):
out, = ctx.saved_tensors
gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
ctx.scale)
return gradgrad_out, None, None, None
class FusedLeakyReLUFunction(Function):
@staticmethod
def forward(ctx, input, bias, negative_slope, scale):
empty = input.new_empty(0)
out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale
return out
@staticmethod
def backward(ctx, grad_output):
out, = ctx.saved_tensors
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
return grad_input, grad_bias, None, None
class FusedLeakyReLU(nn.Module):
def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
super().__init__()
self.bias = nn.Parameter(torch.zeros(channel))
self.negative_slope = negative_slope
self.scale = scale
def forward(self, input):
return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
================================================
FILE: basicsr/ops/fused_act/src/fused_bias_act.cpp
================================================
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
#include
torch::Tensor fused_bias_act_op(const torch::Tensor& input,
const torch::Tensor& bias,
const torch::Tensor& refer,
int act, int grad, float alpha, float scale);
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor fused_bias_act(const torch::Tensor& input,
const torch::Tensor& bias,
const torch::Tensor& refer,
int act, int grad, float alpha, float scale) {
CHECK_CUDA(input);
CHECK_CUDA(bias);
return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
================================================
FILE: basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
================================================
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include
#include
#include
#include
#include
#include
#include
template
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
scalar_t zero = 0.0;
for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
scalar_t x = p_x[xi];
if (use_bias) {
x += p_b[(xi / step_b) % size_b];
}
scalar_t ref = use_ref ? p_ref[xi] : zero;
scalar_t y;
switch (act * 10 + grad) {
default:
case 10: y = x; break;
case 11: y = x; break;
case 12: y = 0.0; break;
case 30: y = (x > 0.0) ? x : x * alpha; break;
case 31: y = (ref > 0.0) ? x : x * alpha; break;
case 32: y = 0.0; break;
}
out[xi] = y * scale;
}
}
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
int act, int grad, float alpha, float scale) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
auto x = input.contiguous();
auto b = bias.contiguous();
auto ref = refer.contiguous();
int use_bias = b.numel() ? 1 : 0;
int use_ref = ref.numel() ? 1 : 0;
int size_x = x.numel();
int size_b = b.numel();
int step_b = 1;
for (int i = 1 + 1; i < x.dim(); i++) {
step_b *= x.size(i);
}
int loop_x = 4;
int block_size = 4 * 32;
int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
auto y = torch::empty_like(x);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
fused_bias_act_kernel<<>>(
y.data_ptr(),
x.data_ptr(),
b.data_ptr(),
ref.data_ptr(),
act,
grad,
alpha,
scale,
loop_x,
size_x,
step_b,
size_b,
use_bias,
use_ref
);
});
return y;
}
================================================
FILE: basicsr/ops/upfirdn2d/__init__.py
================================================
from .upfirdn2d import upfirdn2d
__all__ = ['upfirdn2d']
================================================
FILE: basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
================================================
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
#include
torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1);
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
int up_x, int up_y, int down_x, int down_y,
int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
CHECK_CUDA(input);
CHECK_CUDA(kernel);
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
================================================
FILE: basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
================================================
// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include
#include
#include
#include
#include
#include
#include
static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
int c = a / b;
if (c * b > a) {
c--;
}
return c;
}
struct UpFirDn2DKernelParams {
int up_x;
int up_y;
int down_x;
int down_y;
int pad_x0;
int pad_x1;
int pad_y0;
int pad_y1;
int major_dim;
int in_h;
int in_w;
int minor_dim;
int kernel_h;
int kernel_w;
int out_h;
int out_w;
int loop_major;
int loop_x;
};
template
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
int out_y = minor_idx / p.minor_dim;
minor_idx -= out_y * p.minor_dim;
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
int major_idx_base = blockIdx.z * p.loop_major;
if (out_x_base >= p.out_w || out_y >= p.out_h ||
major_idx_base >= p.major_dim) {
return;
}
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major && major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, out_x = out_x_base;
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
const scalar_t *x_p =
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
minor_idx];
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
int x_px = p.minor_dim;
int k_px = -p.up_x;
int x_py = p.in_w * p.minor_dim;
int k_py = -p.up_y * p.kernel_w;
scalar_t v = 0.0f;
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
v += static_cast(*x_p) * static_cast(*k_p);
x_p += x_px;
k_p += k_px;
}
x_p += x_py - w * x_px;
k_p += k_py - w * k_px;
}
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
template
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
__shared__ volatile float sk[kernel_h][kernel_w];
__shared__ volatile float sx[tile_in_h][tile_in_w];
int minor_idx = blockIdx.x;
int tile_out_y = minor_idx / p.minor_dim;
minor_idx -= tile_out_y * p.minor_dim;
tile_out_y *= tile_out_h;
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
int major_idx_base = blockIdx.z * p.loop_major;
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
major_idx_base >= p.major_dim) {
return;
}
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
tap_idx += blockDim.x) {
int ky = tap_idx / kernel_w;
int kx = tap_idx - ky * kernel_w;
scalar_t v = 0.0;
if (kx < p.kernel_w & ky < p.kernel_h) {
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
}
sk[ky][kx] = v;
}
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major & major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, tile_out_x = tile_out_x_base;
loop_x < p.loop_x & tile_out_x < p.out_w;
loop_x++, tile_out_x += tile_out_w) {
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
int tile_in_x = floor_div(tile_mid_x, up_x);
int tile_in_y = floor_div(tile_mid_y, up_y);
__syncthreads();
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
in_idx += blockDim.x) {
int rel_in_y = in_idx / tile_in_w;
int rel_in_x = in_idx - rel_in_y * tile_in_w;
int in_x = rel_in_x + tile_in_x;
int in_y = rel_in_y + tile_in_y;
scalar_t v = 0.0;
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
p.minor_dim +
minor_idx];
}
sx[rel_in_y][rel_in_x] = v;
}
__syncthreads();
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
out_idx += blockDim.x) {
int rel_out_y = out_idx / tile_out_w;
int rel_out_x = out_idx - rel_out_y * tile_out_w;
int out_x = rel_out_x + tile_out_x;
int out_y = rel_out_y + tile_out_y;
int mid_x = tile_mid_x + rel_out_x * down_x;
int mid_y = tile_mid_y + rel_out_y * down_y;
int in_x = floor_div(mid_x, up_x);
int in_y = floor_div(mid_y, up_y);
int rel_in_x = in_x - tile_in_x;
int rel_in_y = in_y - tile_in_y;
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
scalar_t v = 0.0;
#pragma unroll
for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
for (int x = 0; x < kernel_w / up_x; x++)
v += sx[rel_in_y + y][rel_in_x + x] *
sk[kernel_y + y * up_y][kernel_x + x * up_x];
if (out_x < p.out_w & out_y < p.out_h) {
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
}
}
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
const torch::Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
UpFirDn2DKernelParams p;
auto x = input.contiguous();
auto k = kernel.contiguous();
p.major_dim = x.size(0);
p.in_h = x.size(1);
p.in_w = x.size(2);
p.minor_dim = x.size(3);
p.kernel_h = k.size(0);
p.kernel_w = k.size(1);
p.up_x = up_x;
p.up_y = up_y;
p.down_x = down_x;
p.down_y = down_y;
p.pad_x0 = pad_x0;
p.pad_x1 = pad_x1;
p.pad_y0 = pad_y0;
p.pad_y1 = pad_y1;
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
p.down_y;
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
p.down_x;
auto out =
at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
int mode = -1;
int tile_out_h = -1;
int tile_out_w = -1;
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 1;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 3 && p.kernel_w <= 3) {
mode = 2;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 3;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 4;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
mode = 5;
tile_out_h = 8;
tile_out_w = 32;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
mode = 6;
tile_out_h = 8;
tile_out_w = 32;
}
dim3 block_size;
dim3 grid_size;
if (tile_out_h > 0 && tile_out_w > 0) {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 1;
block_size = dim3(32 * 8, 1, 1);
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
(p.major_dim - 1) / p.loop_major + 1);
} else {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 4;
block_size = dim3(4, 32, 1);
grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
(p.out_w - 1) / (p.loop_x * block_size.y) + 1,
(p.major_dim - 1) / p.loop_major + 1);
}
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
switch (mode) {
case 1:
upfirdn2d_kernel
<<>>(out.data_ptr(),
x.data_ptr(),
k.data_ptr(), p);
break;
case 2:
upfirdn2d_kernel
<<>>(out.data_ptr(),
x.data_ptr(),
k.data_ptr(), p);
break;
case 3:
upfirdn2d_kernel
<<>>(out.data_ptr(),
x.data_ptr(),
k.data_ptr(), p);
break;
case 4:
upfirdn2d_kernel
<<>>(out.data_ptr(),
x.data_ptr(),
k.data_ptr(), p);
break;
case 5:
upfirdn2d_kernel
<<>>(out.data_ptr(),
x.data_ptr(),
k.data_ptr(), p);
break;
case 6:
upfirdn2d_kernel
<<>>(out.data_ptr(),
x.data_ptr(),
k.data_ptr(), p);
break;
default:
upfirdn2d_kernel_large<<>>(
out.data_ptr(), x.data_ptr(),
k.data_ptr(), p);
}
});
return out;
}
================================================
FILE: basicsr/ops/upfirdn2d/upfirdn2d.py
================================================
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
import torch
from torch.autograd import Function
from torch.nn import functional as F
try:
from . import upfirdn2d_ext
except ImportError:
import os
BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
upfirdn2d_ext = load(
'upfirdn2d',
sources=[
os.path.join(module_path, 'src', 'upfirdn2d.cpp'),
os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'),
],
)
class UpFirDn2dBackward(Function):
@staticmethod
def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
up_x, up_y = up
down_x, down_y = down
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
grad_input = upfirdn2d_ext.upfirdn2d(
grad_output,
grad_kernel,
down_x,
down_y,
up_x,
up_y,
g_pad_x0,
g_pad_x1,
g_pad_y0,
g_pad_y1,
)
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
ctx.save_for_backward(kernel)
pad_x0, pad_x1, pad_y0, pad_y1 = pad
ctx.up_x = up_x
ctx.up_y = up_y
ctx.down_x = down_x
ctx.down_y = down_y
ctx.pad_x0 = pad_x0
ctx.pad_x1 = pad_x1
ctx.pad_y0 = pad_y0
ctx.pad_y1 = pad_y1
ctx.in_size = in_size
ctx.out_size = out_size
return grad_input
@staticmethod
def backward(ctx, gradgrad_input):
kernel, = ctx.saved_tensors
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
gradgrad_out = upfirdn2d_ext.upfirdn2d(
gradgrad_input,
kernel,
ctx.up_x,
ctx.up_y,
ctx.down_x,
ctx.down_y,
ctx.pad_x0,
ctx.pad_x1,
ctx.pad_y0,
ctx.pad_y1,
)
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
# ctx.out_size[1], ctx.in_size[3])
gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])
return gradgrad_out, None, None, None, None, None, None, None, None
class UpFirDn2d(Function):
@staticmethod
def forward(ctx, input, kernel, up, down, pad):
up_x, up_y = up
down_x, down_y = down
pad_x0, pad_x1, pad_y0, pad_y1 = pad
kernel_h, kernel_w = kernel.shape
batch, channel, in_h, in_w = input.shape
ctx.in_size = input.shape
input = input.reshape(-1, in_h, in_w, 1)
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
ctx.out_size = (out_h, out_w)
ctx.up = (up_x, up_y)
ctx.down = (down_x, down_y)
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
g_pad_x0 = kernel_w - pad_x0 - 1
g_pad_y0 = kernel_h - pad_y0 - 1
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
# out = out.view(major, out_h, out_w, minor)
out = out.view(-1, channel, out_h, out_w)
return out
@staticmethod
def backward(ctx, grad_output):
kernel, grad_kernel = ctx.saved_tensors
grad_input = UpFirDn2dBackward.apply(
grad_output,
kernel,
grad_kernel,
ctx.up,
ctx.down,
ctx.pad,
ctx.g_pad,
ctx.in_size,
ctx.out_size,
)
return grad_input, None, None, None, None
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
if input.device.type == 'cpu':
out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
else:
out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))
return out
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
out = out.permute(0, 3, 1, 2)
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
return out.view(-1, channel, out_h, out_w)
================================================
FILE: basicsr/setup.py
================================================
#!/usr/bin/env python
from setuptools import find_packages, setup
import os
import subprocess
import sys
import time
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
from utils.misc import gpu_is_available
version_file = './basicsr/version.py'
def readme():
with open('README.md', encoding='utf-8') as f:
content = f.read()
return content
def get_git_hash():
def _minimal_ext_cmd(cmd):
# construct minimal environment
env = {}
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
v = os.environ.get(k)
if v is not None:
env[k] = v
# LANGUAGE is used on win32
env['LANGUAGE'] = 'C'
env['LANG'] = 'C'
env['LC_ALL'] = 'C'
out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
return out
try:
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
sha = out.strip().decode('ascii')
except OSError:
sha = 'unknown'
return sha
def get_hash():
if os.path.exists('.git'):
sha = get_git_hash()[:7]
elif os.path.exists(version_file):
try:
from version import __version__
sha = __version__.split('+')[-1]
except ImportError:
raise ImportError('Unable to get git version')
else:
sha = 'unknown'
return sha
def write_version_py():
content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
sha = get_hash()
with open('./basicsr/VERSION', 'r') as f:
SHORT_VERSION = f.read().strip()
VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
with open(version_file, 'w') as f:
f.write(version_file_str)
def get_version():
with open(version_file, 'r') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']
def make_cuda_ext(name, module, sources, sources_cuda=None):
if sources_cuda is None:
sources_cuda = []
define_macros = []
extra_compile_args = {'cxx': []}
# if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
if gpu_is_available or os.getenv('FORCE_CUDA', '0') == '1':
define_macros += [('WITH_CUDA', None)]
extension = CUDAExtension
extra_compile_args['nvcc'] = [
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
]
sources += sources_cuda
else:
print(f'Compiling {name} without CUDA')
extension = CppExtension
return extension(
name=f'{module}.{name}',
sources=[os.path.join(*module.split('.'), p) for p in sources],
define_macros=define_macros,
extra_compile_args=extra_compile_args)
def get_requirements(filename='requirements.txt'):
with open(os.path.join('.', filename), 'r') as f:
requires = [line.replace('\n', '') for line in f.readlines()]
return requires
if __name__ == '__main__':
if '--cuda_ext' in sys.argv:
ext_modules = [
make_cuda_ext(
name='deform_conv_ext',
module='ops.dcn',
sources=['src/deform_conv_ext.cpp'],
sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
make_cuda_ext(
name='fused_act_ext',
module='ops.fused_act',
sources=['src/fused_bias_act.cpp'],
sources_cuda=['src/fused_bias_act_kernel.cu']),
make_cuda_ext(
name='upfirdn2d_ext',
module='ops.upfirdn2d',
sources=['src/upfirdn2d.cpp'],
sources_cuda=['src/upfirdn2d_kernel.cu']),
]
sys.argv.remove('--cuda_ext')
else:
ext_modules = []
write_version_py()
setup(
name='basicsr',
version=get_version(),
description='Open Source Image and Video Super-Resolution Toolbox',
long_description=readme(),
long_description_content_type='text/markdown',
author='Xintao Wang',
author_email='xintao.wang@outlook.com',
keywords='computer vision, restoration, super resolution',
url='https://github.com/xinntao/BasicSR',
include_package_data=True,
packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
classifiers=[
'Development Status :: 4 - Beta',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
],
license='Apache License 2.0',
setup_requires=['cython', 'numpy'],
install_requires=get_requirements(),
ext_modules=ext_modules,
cmdclass={'build_ext': BuildExtension},
zip_safe=False)
================================================
FILE: basicsr/train.py
================================================
import argparse
import datetime
import logging
import math
import copy
import random
import time
import torch
from os import path as osp
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"]="expandable_segments:True"
from basicsr.data import build_dataloader, build_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import build_model
from basicsr.utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger,
init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed)
from basicsr.utils.dist_util import get_dist_info, init_dist
from basicsr.utils.options import dict2str, parse
import warnings
# ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`.
warnings.filterwarnings("ignore", category=UserWarning)
def parse_options(root_path, is_train=True):
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
parser.add_argument('--local-rank', type=int, default=0)
parser.add_argument('--rank', type=int, default=0)
args = parser.parse_args()
opt = parse(args.opt, root_path, is_train=is_train)
# distributed settings
if args.launcher == 'none':
opt['dist'] = False
print('Disable distributed.', flush=True)
else:
opt['dist'] = True
if args.launcher == 'slurm' and 'dist_params' in opt:
init_dist(args.launcher, **opt['dist_params'])
else:
init_dist(args.launcher)
opt['rank'], opt['world_size'] = get_dist_info()
# print(opt['rank'], opt['world_size'])
# exit()
# random seed
seed = opt.get('manual_seed')
if seed is None:
seed = random.randint(1, 10000)
opt['manual_seed'] = seed
set_random_seed(seed + opt['rank'])
return opt
def init_loggers(opt):
log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
logger.info(get_env_info())
logger.info(dict2str(opt))
# initialize wandb logger before tensorboard logger to allow proper sync:
if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None):
assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
init_wandb_logger(opt)
tb_logger = None
if opt['logger'].get('use_tb_logger'):
tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
return logger, tb_logger
def create_train_val_dataloader(opt, logger):
# create train and val dataloaders
train_loader, val_loader = None, None
for phase, dataset_opt in opt['datasets'].items():
if phase == 'train':
dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
train_set = build_dataset(dataset_opt)
train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
train_loader = build_dataloader(
train_set,
dataset_opt,
num_gpu=opt['num_gpu'],
dist=opt['dist'],
sampler=train_sampler,
seed=opt['manual_seed'])
print(len(train_set), dataset_enlarge_ratio, dataset_opt['batch_size_per_gpu'], opt['world_size'])
num_iter_per_epoch = math.ceil(
len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
total_iters = int(opt['train']['total_iter'])
total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
logger.info('Training statistics:'
f'\n\tNumber of train images: {len(train_set)}'
f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
f'\n\tWorld size (gpu number): {opt["world_size"]}'
f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
elif phase == 'val':
val_set = build_dataset(dataset_opt)
val_loader = build_dataloader(
val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}')
else:
raise ValueError(f'Dataset phase {phase} is not recognized.')
return train_loader, train_sampler, val_loader, total_epochs, total_iters
def train_pipeline(root_path):
# parse options, set distributed setting, set ramdom seed
opt = parse_options(root_path, is_train=True)
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
# load resume states if necessary
if opt['path'].get('resume_state'):
device_id = torch.cuda.current_device()
resume_state = torch.load(
opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id))
else:
resume_state = None
# mkdir for experiments and logger
if resume_state is None:
make_exp_dirs(opt)
if opt['logger'].get('use_tb_logger') and opt['rank'] == 0:
mkdir_and_rename(osp.join('tb_logger', opt['name']))
# initialize loggers
logger, tb_logger = init_loggers(opt)
# create train and validation dataloaders
result = create_train_val_dataloader(opt, logger)
train_loader, train_sampler, val_loader, total_epochs, total_iters = result
# create model
if resume_state: # resume training
check_resume(opt, resume_state['iter'])
model = build_model(opt)
model.resume_training(resume_state) # handle optimizers and schedulers
logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.")
start_epoch = resume_state['epoch']
current_iter = resume_state['iter']
else:
model = build_model(opt)
start_epoch = 0
current_iter = 0
# create message logger (formatted outputs)
msg_logger = MessageLogger(opt, current_iter, tb_logger)
# dataloader prefetcher
prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
if prefetch_mode is None or prefetch_mode == 'cpu':
prefetcher = CPUPrefetcher(train_loader)
elif prefetch_mode == 'cuda':
prefetcher = CUDAPrefetcher(train_loader, opt)
logger.info(f'Use {prefetch_mode} prefetch dataloader')
if opt['datasets']['train'].get('pin_memory') is not True:
raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
else:
raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.' "Supported ones are: None, 'cuda', 'cpu'.")
# training
logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter+1}')
data_time, iter_time = time.time(), time.time()
start_time = time.time()
for epoch in range(start_epoch, total_epochs + 1):
train_sampler.set_epoch(epoch)
prefetcher.reset()
train_data = prefetcher.next()
while train_data is not None:
data_time = time.time() - data_time
current_iter += 1
if current_iter > total_iters:
break
# update learning rate
model.update_learning_rate(current_iter,
warmup_iter=opt['train'].get('warmup_iter', -1))
# training
model.feed_data(train_data)
model.optimize_parameters(current_iter)
iter_time = time.time() - iter_time
# log
if current_iter % opt['logger']['print_freq'] == 0:
log_vars = {'epoch': epoch, 'iter': current_iter}
log_vars.update({'lrs': model.get_current_learning_rate()})
log_vars.update({'time': iter_time, 'data_time': data_time})
log_vars.update(model.get_current_log())
msg_logger(log_vars)
# save models and training states
if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
logger.info('Saving models and training states.')
model.save(epoch, current_iter)
# validation
if opt.get('val') is not None and opt['datasets'].get('val') is not None \
and (current_iter % opt['val']['val_freq'] == 0):
model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
data_time = time.time()
iter_time = time.time()
train_data = prefetcher.next()
# end of iter
# end of epoch
consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
logger.info(f'End of training. Time consumed: {consumed_time}')
logger.info('Save the latest model.')
model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
if opt.get('val') is not None and opt['datasets'].get('val'):
model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
if tb_logger:
tb_logger.close()
if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
train_pipeline(root_path)
================================================
FILE: basicsr/utils/__init__.py
================================================
from .file_client import FileClient
from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, tensor2imgs, images_to_gif
from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
__all__ = [
# file_client.py
'FileClient',
# img_util.py
'img2tensor',
'tensor2img',
'imfrombytes',
'imwrite',
'crop_border',
# logger.py
'MessageLogger',
'init_tb_logger',
'init_wandb_logger',
'get_root_logger',
'get_env_info',
# misc.py
'set_random_seed',
'get_time_str',
'mkdir_and_rename',
'make_exp_dirs',
'scandir',
'check_resume',
'sizeof_fmt',
# new add
'tensor2imgs',
'images_to_gif',
]
================================================
FILE: basicsr/utils/dist_util.py
================================================
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
import functools
import os
import subprocess
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def init_dist(launcher, backend='nccl', **kwargs):
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
if launcher == 'pytorch':
_init_dist_pytorch(backend, **kwargs)
elif launcher == 'slurm':
_init_dist_slurm(backend, **kwargs)
else:
raise ValueError(f'Invalid launcher type: {launcher}')
def _init_dist_pytorch(backend, **kwargs):
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
print(f'Initializing PyTorch distributed with rank {rank} and {num_gpus} GPUs.')
# exit()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
def _init_dist_slurm(backend, port=None):
"""Initialize slurm distributed training environment.
If argument ``port`` is not specified, then the master port will be system
environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
environment variable, then a default port ``29500`` will be used.
Args:
backend (str): Backend of torch.distributed.
port (int, optional): Master port. Defaults to None.
"""
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(proc_id % num_gpus)
addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
# specify master port
if port is not None:
os.environ['MASTER_PORT'] = str(port)
elif 'MASTER_PORT' in os.environ:
pass # use MASTER_PORT in the environment variable
else:
# 29500 is torch.distributed default port
os.environ['MASTER_PORT'] = '29500'
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
os.environ['RANK'] = str(proc_id)
dist.init_process_group(backend=backend)
def get_dist_info():
if dist.is_available():
initialized = dist.is_initialized()
else:
initialized = False
if initialized:
rank = dist.get_rank()
world_size = dist.get_world_size()
else:
rank = 0
world_size = 1
return rank, world_size
def master_only(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
rank, _ = get_dist_info()
if rank == 0:
return func(*args, **kwargs)
return wrapper
================================================
FILE: basicsr/utils/download_util.py
================================================
import math
import os
import requests
from torch.hub import download_url_to_file, get_dir
from tqdm import tqdm
from urllib.parse import urlparse
from .misc import sizeof_fmt
def download_file_from_google_drive(file_id, save_path):
"""Download files from google drive.
Ref:
https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
Args:
file_id (str): File id.
save_path (str): Save path.
"""
session = requests.Session()
URL = 'https://docs.google.com/uc?export=download'
params = {'id': file_id}
response = session.get(URL, params=params, stream=True)
token = get_confirm_token(response)
if token:
params['confirm'] = token
response = session.get(URL, params=params, stream=True)
# get file size
response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
print(response_file_size)
if 'Content-Range' in response_file_size.headers:
file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
else:
file_size = None
save_response_content(response, save_path, file_size)
def get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None
def save_response_content(response, destination, file_size=None, chunk_size=32768):
if file_size is not None:
pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
readable_file_size = sizeof_fmt(file_size)
else:
pbar = None
with open(destination, 'wb') as f:
downloaded_size = 0
for chunk in response.iter_content(chunk_size):
downloaded_size += chunk_size
if pbar is not None:
pbar.update(1)
pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
if chunk: # filter out keep-alive new chunks
f.write(chunk)
if pbar is not None:
pbar.close()
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Load file form http url, will download models if necessary.
Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
Args:
url (str): URL to be downloaded.
model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
Default: None.
progress (bool): Whether to show the download progress. Default: True.
file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
Returns:
str: The path to the downloaded file.
"""
if model_dir is None: # use the pytorch hub_dir
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(model_dir, exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file
================================================
FILE: basicsr/utils/file_client.py
================================================
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
from abc import ABCMeta, abstractmethod
class BaseStorageBackend(metaclass=ABCMeta):
"""Abstract class of storage backends.
All backends need to implement two apis: ``get()`` and ``get_text()``.
``get()`` reads the file as a byte stream and ``get_text()`` reads the file
as texts.
"""
@abstractmethod
def get(self, filepath):
pass
@abstractmethod
def get_text(self, filepath):
pass
class MemcachedBackend(BaseStorageBackend):
"""Memcached storage backend.
Attributes:
server_list_cfg (str): Config file for memcached server list.
client_cfg (str): Config file for memcached client.
sys_path (str | None): Additional path to be appended to `sys.path`.
Default: None.
"""
def __init__(self, server_list_cfg, client_cfg, sys_path=None):
if sys_path is not None:
import sys
sys.path.append(sys_path)
try:
import mc
except ImportError:
raise ImportError('Please install memcached to enable MemcachedBackend.')
self.server_list_cfg = server_list_cfg
self.client_cfg = client_cfg
self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
# mc.pyvector servers as a point which points to a memory cache
self._mc_buffer = mc.pyvector()
def get(self, filepath):
filepath = str(filepath)
import mc
self._client.Get(filepath, self._mc_buffer)
value_buf = mc.ConvertBuffer(self._mc_buffer)
return value_buf
def get_text(self, filepath):
raise NotImplementedError
class HardDiskBackend(BaseStorageBackend):
"""Raw hard disks storage backend."""
def get(self, filepath):
filepath = str(filepath)
with open(filepath, 'rb') as f:
value_buf = f.read()
return value_buf
def get_text(self, filepath):
filepath = str(filepath)
with open(filepath, 'r') as f:
value_buf = f.read()
return value_buf
class LmdbBackend(BaseStorageBackend):
"""Lmdb storage backend.
Args:
db_paths (str | list[str]): Lmdb database paths.
client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
readonly (bool, optional): Lmdb environment parameter. If True,
disallow any write operations. Default: True.
lock (bool, optional): Lmdb environment parameter. If False, when
concurrent access occurs, do not lock the database. Default: False.
readahead (bool, optional): Lmdb environment parameter. If False,
disable the OS filesystem readahead mechanism, which may improve
random read performance when a database is larger than RAM.
Default: False.
Attributes:
db_paths (list): Lmdb database path.
_client (list): A list of several lmdb envs.
"""
def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
try:
import lmdb
except ImportError:
raise ImportError('Please install lmdb to enable LmdbBackend.')
if isinstance(client_keys, str):
client_keys = [client_keys]
if isinstance(db_paths, list):
self.db_paths = [str(v) for v in db_paths]
elif isinstance(db_paths, str):
self.db_paths = [str(db_paths)]
assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
f'but received {len(client_keys)} and {len(self.db_paths)}.')
self._client = {}
for client, path in zip(client_keys, self.db_paths):
self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
def get(self, filepath, client_key):
"""Get values according to the filepath from one lmdb named client_key.
Args:
filepath (str | obj:`Path`): Here, filepath is the lmdb key.
client_key (str): Used for distinguishing differnet lmdb envs.
"""
filepath = str(filepath)
assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
client = self._client[client_key]
with client.begin(write=False) as txn:
value_buf = txn.get(filepath.encode('ascii'))
return value_buf
def get_text(self, filepath):
raise NotImplementedError
class FileClient(object):
"""A general file client to access files in different backend.
The client loads a file or text in a specified backend from its path
and return it as a binary file. it can also register other backend
accessor with a given name and backend class.
Attributes:
backend (str): The storage backend type. Options are "disk",
"memcached" and "lmdb".
client (:obj:`BaseStorageBackend`): The backend object.
"""
_backends = {
'disk': HardDiskBackend,
'memcached': MemcachedBackend,
'lmdb': LmdbBackend,
}
def __init__(self, backend='disk', **kwargs):
if backend not in self._backends:
raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
f' are {list(self._backends.keys())}')
self.backend = backend
self.client = self._backends[backend](**kwargs)
def get(self, filepath, client_key='default'):
# client_key is used only for lmdb, where different fileclients have
# different lmdb environments.
if self.backend == 'lmdb':
return self.client.get(filepath, client_key)
else:
return self.client.get(filepath)
def get_text(self, filepath):
return self.client.get_text(filepath)
================================================
FILE: basicsr/utils/img_util.py
================================================
import cv2
import math
import numpy as np
from PIL import Image
import os
import torch
from torchvision.utils import make_grid
def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""
def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
if img.dtype == 'float64':
img = img.astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img
if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
"""Convert torch Tensors into image numpy arrays.
After clamping to [min, max], values will be normalized to [0, 1].
Args:
tensor (Tensor or list[Tensor]): Accept shapes:
1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
2) 3D Tensor of shape (3/1 x H x W);
3) 2D Tensor of shape (H x W).
Tensor channel should be in RGB order.
rgb2bgr (bool): Whether to change rgb to bgr.
out_type (numpy type): output types. If ``np.uint8``, transform outputs
to uint8 type with range [0, 255]; otherwise, float type with
range [0, 1]. Default: ``np.uint8``.
min_max (tuple[int]): min and max values for clamp.
Returns:
(Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
shape (H x W). The channel order is BGR.
"""
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
if torch.is_tensor(tensor):
tensor = [tensor]
result = []
for _tensor in tensor:
_tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
_tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
n_dim = _tensor.dim()
if n_dim == 4:
img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
img_np = img_np.transpose(1, 2, 0)
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 3:
img_np = _tensor.numpy()
img_np = img_np.transpose(1, 2, 0)
if img_np.shape[2] == 1: # gray image
img_np = np.squeeze(img_np, axis=2)
else:
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 2:
img_np = _tensor.numpy()
else:
raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}')
if out_type == np.uint8:
# Unlike MATLAB, numpy.unit8() WILL NOT round by default.
img_np = (img_np * 255.0).round()
img_np = img_np.astype(out_type)
result.append(img_np)
if len(result) == 1:
result = result[0]
return result
def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
"""This implementation is slightly faster than tensor2img.
It now only supports torch tensor with shape (1, c, h, w).
Args:
tensor (Tensor): Now only support torch tensor with (1, c, h, w).
rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
min_max (tuple[int]): min and max values for clamp.
"""
output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
output = output.type(torch.uint8).cpu().numpy()
if rgb2bgr:
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
def tensor2imgs(tensor, rgb2bgr=True, min_max=(0, 1)):
"""Convert a 4D torch tensor to a list of numpy images.
Args:
tensor (Tensor): A 4D torch tensor with shape (B, C, H, W).
rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
min_max (tuple[int]): min and max values for clamp.
Returns:
list: A list of numpy arrays representing images.
"""
# 检查输入是否为 4D 张量
if tensor.dim() != 4:
raise ValueError(f"Input tensor should be 4D (B, C, H, W), but got {tensor.dim()}D tensor.")
num_images = tensor.size(0)
image_list = []
# 遍历批量中的每个图像
for i in range(num_images):
single_image_tensor = tensor[i].unsqueeze(0) # 提取单张图像并添加一个维度以匹配 tensor2img_fast 的输入要求
single_image_np = tensor2img_fast(single_image_tensor, rgb2bgr=rgb2bgr, min_max=min_max)
image_list.append(single_image_np)
return image_list
def imfrombytes(content, flag='color', float32=False):
"""Read an image from bytes.
Args:
content (bytes): Image bytes got from files or other streams.
flag (str): Flags specifying the color type of a loaded image,
candidates are `color`, `grayscale` and `unchanged`.
float32 (bool): Whether to change to float32., If True, will also norm
to [0, 1]. Default: False.
Returns:
ndarray: Loaded image array.
"""
img_np = np.frombuffer(content, np.uint8)
imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
img = cv2.imdecode(img_np, imread_flags[flag])
if float32:
img = img.astype(np.float32) / 255.
return img
def imwrite(img, file_path, params=None, auto_mkdir=True):
"""Write image to file.
Args:
img (ndarray): Image array to be written.
file_path (str): Image file path.
params (None or list): Same as opencv's :func:`imwrite` interface.
auto_mkdir (bool): If the parent folder of `file_path` does not exist,
whether to create it automatically.
Returns:
bool: Successful or not.
"""
if auto_mkdir:
dir_name = os.path.abspath(os.path.dirname(file_path))
os.makedirs(dir_name, exist_ok=True)
return cv2.imwrite(file_path, img, params)
def images_to_gif(image_list, output_path, duration=100, loop=0):
"""
将包含 numpy.ndarray 类型图像的列表拼接成一个 GIF 动画。
Args:
image_list (list): 包含 numpy.ndarray 对象的列表,代表图像数据。
output_path (str): 输出 GIF 文件的路径。
duration (int): 每一帧的显示时间(毫秒),默认为 100 毫秒。
loop (int): GIF 动画的循环次数,0 表示无限循环,默认为 0。
"""
# 确保 image_list 不为空
if not image_list:
print("图像列表为空,无法创建 GIF。")
return
pil_images = []
for img in image_list:
# 检查图像是否为单通道灰度图
if len(img.shape) == 2:
pil_img = Image.fromarray(img, mode='L')
else:
# 通常 OpenCV 读取的图像是 BGR 格式,而 PIL 使用 RGB 格式,需要转换
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if 'cv2' in globals() else img
pil_img = Image.fromarray(img)
pil_images.append(pil_img)
# 保存为 GIF
pil_images[0].save(
output_path,
save_all=True,
append_images=pil_images[1:],
duration=duration,
loop=loop
)
def crop_border(imgs, crop_border):
"""Crop borders of images.
Args:
imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
crop_border (int): Crop border for each end of height and weight.
Returns:
list[ndarray]: Cropped images.
"""
if crop_border == 0:
return imgs
else:
if isinstance(imgs, list):
return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
else:
return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
================================================
FILE: basicsr/utils/lmdb_util.py
================================================
import cv2
import lmdb
import sys
from multiprocessing import Pool
from os import path as osp
from tqdm import tqdm
def make_lmdb_from_imgs(data_path,
lmdb_path,
img_path_list,
keys,
batch=5000,
compress_level=1,
multiprocessing_read=False,
n_thread=40,
map_size=None):
"""Make lmdb from images.
Contents of lmdb. The file structure is:
example.lmdb
├── data.mdb
├── lock.mdb
├── meta_info.txt
The data.mdb and lock.mdb are standard lmdb files and you can refer to
https://lmdb.readthedocs.io/en/release/ for more details.
The meta_info.txt is a specified txt file to record the meta information
of our datasets. It will be automatically created when preparing
datasets by our provided dataset tools.
Each line in the txt file records 1)image name (with extension),
2)image shape, and 3)compression level, separated by a white space.
For example, the meta information could be:
`000_00000000.png (720,1280,3) 1`, which means:
1) image name (with extension): 000_00000000.png;
2) image shape: (720,1280,3);
3) compression level: 1
We use the image name without extension as the lmdb key.
If `multiprocessing_read` is True, it will read all the images to memory
using multiprocessing. Thus, your server needs to have enough memory.
Args:
data_path (str): Data path for reading images.
lmdb_path (str): Lmdb save path.
img_path_list (str): Image path list.
keys (str): Used for lmdb keys.
batch (int): After processing batch images, lmdb commits.
Default: 5000.
compress_level (int): Compress level when encoding images. Default: 1.
multiprocessing_read (bool): Whether use multiprocessing to read all
the images to memory. Default: False.
n_thread (int): For multiprocessing.
map_size (int | None): Map size for lmdb env. If None, use the
estimated size from images. Default: None
"""
assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
f'but got {len(img_path_list)} and {len(keys)}')
print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
print(f'Totoal images: {len(img_path_list)}')
if not lmdb_path.endswith('.lmdb'):
raise ValueError("lmdb_path must end with '.lmdb'.")
if osp.exists(lmdb_path):
print(f'Folder {lmdb_path} already exists. Exit.')
sys.exit(1)
if multiprocessing_read:
# read all the images to memory (multiprocessing)
dataset = {} # use dict to keep the order for multiprocessing
shapes = {}
print(f'Read images with multiprocessing, #thread: {n_thread} ...')
pbar = tqdm(total=len(img_path_list), unit='image')
def callback(arg):
"""get the image data and update pbar."""
key, dataset[key], shapes[key] = arg
pbar.update(1)
pbar.set_description(f'Read {key}')
pool = Pool(n_thread)
for path, key in zip(img_path_list, keys):
pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
pool.close()
pool.join()
pbar.close()
print(f'Finish reading {len(img_path_list)} images.')
# create lmdb environment
if map_size is None:
# obtain data size for one image
img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
data_size_per_img = img_byte.nbytes
print('Data size per image is: ', data_size_per_img)
data_size = data_size_per_img * len(img_path_list)
map_size = data_size * 10
env = lmdb.open(lmdb_path, map_size=map_size)
# write data to lmdb
pbar = tqdm(total=len(img_path_list), unit='chunk')
txn = env.begin(write=True)
txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
for idx, (path, key) in enumerate(zip(img_path_list, keys)):
pbar.update(1)
pbar.set_description(f'Write {key}')
key_byte = key.encode('ascii')
if multiprocessing_read:
img_byte = dataset[key]
h, w, c = shapes[key]
else:
_, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
h, w, c = img_shape
txn.put(key_byte, img_byte)
# write meta information
txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
if idx % batch == 0:
txn.commit()
txn = env.begin(write=True)
pbar.close()
txn.commit()
env.close()
txt_file.close()
print('\nFinish writing lmdb.')
def read_img_worker(path, key, compress_level):
"""Read image worker.
Args:
path (str): Image path.
key (str): Image key.
compress_level (int): Compress level when encoding images.
Returns:
str: Image key.
byte: Image byte.
tuple[int]: Image shape.
"""
img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
if img.ndim == 2:
h, w = img.shape
c = 1
else:
h, w, c = img.shape
_, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
return (key, img_byte, (h, w, c))
class LmdbMaker():
"""LMDB Maker.
Args:
lmdb_path (str): Lmdb save path.
map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
batch (int): After processing batch images, lmdb commits.
Default: 5000.
compress_level (int): Compress level when encoding images. Default: 1.
"""
def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
if not lmdb_path.endswith('.lmdb'):
raise ValueError("lmdb_path must end with '.lmdb'.")
if osp.exists(lmdb_path):
print(f'Folder {lmdb_path} already exists. Exit.')
sys.exit(1)
self.lmdb_path = lmdb_path
self.batch = batch
self.compress_level = compress_level
self.env = lmdb.open(lmdb_path, map_size=map_size)
self.txn = self.env.begin(write=True)
self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
self.counter = 0
def put(self, img_byte, key, img_shape):
self.counter += 1
key_byte = key.encode('ascii')
self.txn.put(key_byte, img_byte)
# write meta information
h, w, c = img_shape
self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
if self.counter % self.batch == 0:
self.txn.commit()
self.txn = self.env.begin(write=True)
def close(self):
self.txn.commit()
self.env.close()
self.txt_file.close()
================================================
FILE: basicsr/utils/logger.py
================================================
import datetime
import logging
import time
from .dist_util import get_dist_info, master_only
initialized_logger = {}
class MessageLogger():
"""Message logger for printing.
Args:
opt (dict): Config. It contains the following keys:
name (str): Exp name.
logger (dict): Contains 'print_freq' (str) for logger interval.
train (dict): Contains 'total_iter' (int) for total iters.
use_tb_logger (bool): Use tensorboard logger.
start_iter (int): Start iter. Default: 1.
tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
"""
def __init__(self, opt, start_iter=1, tb_logger=None):
self.exp_name = opt['name']
self.interval = opt['logger']['print_freq']
self.start_iter = start_iter
self.max_iters = opt['train']['total_iter']
self.use_tb_logger = opt['logger']['use_tb_logger']
self.tb_logger = tb_logger
self.start_time = time.time()
self.logger = get_root_logger()
@master_only
def __call__(self, log_vars):
"""Format logging message.
Args:
log_vars (dict): It contains the following keys:
epoch (int): Epoch number.
iter (int): Current iter.
lrs (list): List for learning rates.
time (float): Iter time.
data_time (float): Data time for each iter.
"""
# epoch, iter, learning rates
epoch = log_vars.pop('epoch')
current_iter = log_vars.pop('iter')
lrs = log_vars.pop('lrs')
message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(')
for v in lrs:
message += f'{v:.3e},'
message += ')] '
# time and estimated time
if 'time' in log_vars.keys():
iter_time = log_vars.pop('time')
data_time = log_vars.pop('data_time')
total_time = time.time() - self.start_time
time_sec_avg = total_time / (current_iter - self.start_iter + 1)
eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
message += f'[eta: {eta_str}, '
message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
# other items, especially losses
for k, v in log_vars.items():
message += f'{k}: {v:.4e} '
# tensorboard logger
if self.use_tb_logger:
# if k.startswith('l_'):
# self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
# else:
self.tb_logger.add_scalar(k, v, current_iter)
self.logger.info(message)
@master_only
def init_tb_logger(log_dir):
from torch.utils.tensorboard import SummaryWriter
tb_logger = SummaryWriter(log_dir=log_dir)
return tb_logger
@master_only
def init_wandb_logger(opt):
"""We now only use wandb to sync tensorboard log."""
import wandb
logger = logging.getLogger('basicsr')
project = opt['logger']['wandb']['project']
resume_id = opt['logger']['wandb'].get('resume_id')
if resume_id:
wandb_id = resume_id
resume = 'allow'
logger.warning(f'Resume wandb logger with id={wandb_id}.')
else:
wandb_id = wandb.util.generate_id()
resume = 'never'
wandb_mode = opt['logger']['wandb'].get('mode', 'offline') # tree mode : offline online disabled
wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True, mode=wandb_mode)
logger.info(f'Use wandb logger with id={wandb_id}; project={project}; mode: {wandb_mode}. ')
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
"""Get the root logger.
The logger will be initialized if it has not been initialized. By default a
StreamHandler will be added. If `log_file` is specified, a FileHandler will
also be added.
Args:
logger_name (str): root logger name. Default: 'basicsr'.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the root logger.
log_level (int): The root logger level. Note that only the process of
rank 0 is affected, while other processes will set the level to
"Error" and be silent most of the time.
Returns:
logging.Logger: The root logger.
"""
logger = logging.getLogger(logger_name)
# if the logger has been initialized, just return it
if logger_name in initialized_logger:
return logger
format_str = '%(asctime)s %(levelname)s: %(message)s'
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter(format_str))
logger.addHandler(stream_handler)
logger.propagate = False
rank, _ = get_dist_info()
if rank != 0:
logger.setLevel('ERROR')
elif log_file is not None:
logger.setLevel(log_level)
# add file handler
# file_handler = logging.FileHandler(log_file, 'w')
file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
file_handler.setFormatter(logging.Formatter(format_str))
file_handler.setLevel(log_level)
logger.addHandler(file_handler)
initialized_logger[logger_name] = True
return logger
def get_env_info():
"""Get environment information.
Currently, only log the software version.
"""
import torch
import torchvision
from basicsr.version import __version__
msg = r"""
____ _ _____ ____
/ __ ) ____ _ _____ (_)_____/ ___/ / __ \
/ __ |/ __ `// ___// // ___/\__ \ / /_/ /
/ /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
/_____/ \__,_//____//_/ \___//____//_/ |_|
______ __ __ __ __
/ ____/____ ____ ____/ / / / __ __ _____ / /__ / /
/ / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
/ /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
\____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
"""
msg += ('\nVersion Information: '
f'\n\tBasicSR: {__version__}'
f'\n\tPyTorch: {torch.__version__}'
f'\n\tTorchVision: {torchvision.__version__}')
return msg
================================================
FILE: basicsr/utils/matlab_functions.py
================================================
import math
import numpy as np
import torch
def cubic(x):
"""cubic function used for calculate_weights_indices."""
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5 * absx3 - 2.5 * absx2 + 1) * (
(absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
(absx <= 2)).type_as(absx))
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
"""Calculate weights and indices, used for imresize function.
Args:
in_length (int): Input length.
out_length (int): Output length.
scale (float): Scale factor.
kernel_width (int): Kernel width.
antialisaing (bool): Whether to apply anti-aliasing when downsampling.
"""
if (scale < 1) and antialiasing:
# Use a modified kernel (larger kernel width) to simultaneously
# interpolate and antialias
kernel_width = kernel_width / scale
# Output-space coordinates
x = torch.linspace(1, out_length, out_length)
# Input-space coordinates. Calculate the inverse mapping such that 0.5
# in output space maps to 0.5 in input space, and 0.5 + scale in output
# space maps to 1.5 in input space.
u = x / scale + 0.5 * (1 - 1 / scale)
# What is the left-most pixel that can be involved in the computation?
left = torch.floor(u - kernel_width / 2)
# What is the maximum number of pixels that can be involved in the
# computation? Note: it's OK to use an extra pixel here; if the
# corresponding weights are all zero, it will be eliminated at the end
# of this function.
p = math.ceil(kernel_width) + 2
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
out_length, p)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
# apply cubic kernel
if (scale < 1) and antialiasing:
weights = scale * cubic(distance_to_center * scale)
else:
weights = cubic(distance_to_center)
# Normalize the weights matrix so that each row sums to 1.
weights_sum = torch.sum(weights, 1).view(out_length, 1)
weights = weights / weights_sum.expand(out_length, p)
# If a column in weights is all zero, get rid of it. only consider the
# first and last column.
weights_zero_tmp = torch.sum((weights == 0), 0)
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
indices = indices.narrow(1, 1, p - 2)
weights = weights.narrow(1, 1, p - 2)
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
indices = indices.narrow(1, 0, p - 2)
weights = weights.narrow(1, 0, p - 2)
weights = weights.contiguous()
indices = indices.contiguous()
sym_len_s = -indices.min() + 1
sym_len_e = indices.max() - in_length
indices = indices + sym_len_s - 1
return weights, indices, int(sym_len_s), int(sym_len_e)
@torch.no_grad()
def imresize(img, scale, antialiasing=True):
"""imresize function same as MATLAB.
It now only supports bicubic.
The same scale applies for both height and width.
Args:
img (Tensor | Numpy array):
Tensor: Input image with shape (c, h, w), [0, 1] range.
Numpy: Input image with shape (h, w, c), [0, 1] range.
scale (float): Scale factor. The same scale applies for both height
and width.
antialisaing (bool): Whether to apply anti-aliasing when downsampling.
Default: True.
Returns:
Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
"""
if type(img).__module__ == np.__name__: # numpy type
numpy_type = True
img = torch.from_numpy(img.transpose(2, 0, 1)).float()
else:
numpy_type = False
in_c, in_h, in_w = img.size()
out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
kernel_width = 4
kernel = 'cubic'
# get weights and indices
weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
antialiasing)
weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
antialiasing)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
sym_patch = img[:, :sym_len_hs, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
sym_patch = img[:, -sym_len_he:, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
out_1 = torch.FloatTensor(in_c, out_h, in_w)
kernel_width = weights_h.size(1)
for i in range(out_h):
idx = int(indices_h[i][0])
for j in range(in_c):
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
# process W dimension
# symmetric copying
out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
sym_patch = out_1[:, :, :sym_len_ws]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
sym_patch = out_1[:, :, -sym_len_we:]
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(2, inv_idx)
out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
out_2 = torch.FloatTensor(in_c, out_h, out_w)
kernel_width = weights_w.size(1)
for i in range(out_w):
idx = int(indices_w[i][0])
for j in range(in_c):
out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
if numpy_type:
out_2 = out_2.numpy().transpose(1, 2, 0)
return out_2
def rgb2ycbcr(img, y_only=False):
"""Convert a RGB image to YCbCr image.
This function produces the same results as Matlab's `rgb2ycbcr` function.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
if y_only:
out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
else:
out_img = np.matmul(
img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def bgr2ycbcr(img, y_only=False):
"""Convert a BGR image to YCbCr image.
The bgr version of rgb2ycbcr.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
y_only (bool): Whether to only return Y channel. Default: False.
Returns:
ndarray: The converted YCbCr image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img)
if y_only:
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
else:
out_img = np.matmul(
img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def ycbcr2rgb(img):
"""Convert a YCbCr image to RGB image.
This function produces the same results as Matlab's ycbcr2rgb function.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
ndarray: The converted RGB image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def ycbcr2bgr(img):
"""Convert a YCbCr image to BGR image.
The bgr version of ycbcr2rgb.
It implements the ITU-R BT.601 conversion for standard-definition
television. See more details in
https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
In OpenCV, it implements a JPEG conversion. See more details in
https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
ndarray: The converted BGR image. The output image has the same type
and range as input image.
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
[0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
out_img = _convert_output_type_range(out_img, img_type)
return out_img
def _convert_input_type_range(img):
"""Convert the type and range of the input image.
It converts the input image to np.float32 type and range of [0, 1].
It is mainly used for pre-processing the input image in colorspace
convertion functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The input image. It accepts:
1. np.uint8 type with range [0, 255];
2. np.float32 type with range [0, 1].
Returns:
(ndarray): The converted image with type of np.float32 and range of
[0, 1].
"""
img_type = img.dtype
img = img.astype(np.float32)
if img_type == np.float32:
pass
elif img_type == np.uint8:
img /= 255.
else:
raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
return img
def _convert_output_type_range(img, dst_type):
"""Convert the type and range of the image according to dst_type.
It converts the image to desired type and range. If `dst_type` is np.uint8,
images will be converted to np.uint8 type with range [0, 255]. If
`dst_type` is np.float32, it converts the image to np.float32 type with
range [0, 1].
It is mainly used for post-processing images in colorspace convertion
functions such as rgb2ycbcr and ycbcr2rgb.
Args:
img (ndarray): The image to be converted with np.float32 type and
range [0, 255].
dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
converts the image to np.uint8 type with range [0, 255]. If
dst_type is np.float32, it converts the image to np.float32 type
with range [0, 1].
Returns:
(ndarray): The converted image with desired type and range.
"""
if dst_type not in (np.uint8, np.float32):
raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
if dst_type == np.uint8:
img = img.round()
else:
img /= 255.
return img.astype(dst_type)
================================================
FILE: basicsr/utils/misc.py
================================================
import os
import re
import random
import time
import torch
import numpy as np
from os import path as osp
from .dist_util import master_only
from .logger import get_root_logger
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
torch.__version__)[0][:3])] >= [1, 12, 0]
def gpu_is_available():
if IS_HIGH_VERSION:
if torch.backends.mps.is_available():
return True
return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
def get_device(gpu_id=None):
if gpu_id is None:
gpu_str = ''
elif isinstance(gpu_id, int):
gpu_str = f':{gpu_id}'
else:
raise TypeError('Input should be int value.')
if IS_HIGH_VERSION:
if torch.backends.mps.is_available():
return torch.device('mps'+gpu_str)
return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
def set_random_seed(seed):
"""Set random seeds."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def get_time_str():
return time.strftime('%Y%m%d_%H%M%S', time.localtime())
def mkdir_and_rename(path):
"""mkdirs. If path exists, rename it with timestamp and create a new one.
Args:
path (str): Folder path.
"""
if osp.exists(path):
new_name = path + '_archived_' + get_time_str()
print(f'Path already exists. Rename it to {new_name}', flush=True)
os.rename(path, new_name)
os.makedirs(path, exist_ok=True)
@master_only
def make_exp_dirs(opt):
"""Make dirs for experiments."""
path_opt = opt['path'].copy()
if opt['is_train']:
mkdir_and_rename(path_opt.pop('experiments_root'))
else:
mkdir_and_rename(path_opt.pop('results_root'))
for key, path in path_opt.items():
if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key):
os.makedirs(path, exist_ok=True)
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
"""Scan a directory to find the interested files.
Args:
dir_path (str): Path of the directory.
suffix (str | tuple(str), optional): File suffix that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
full_path (bool, optional): If set to True, include the dir_path.
Default: False.
Returns:
A generator for all the interested files with relative pathes.
"""
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')
root = dir_path
def _scandir(dir_path, suffix, recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
if full_path:
return_path = entry.path
else:
return_path = osp.relpath(entry.path, root)
if suffix is None:
yield return_path
elif return_path.endswith(suffix):
yield return_path
else:
if recursive:
yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
else:
continue
return _scandir(dir_path, suffix=suffix, recursive=recursive)
def check_resume(opt, resume_iter):
"""Check resume states and pretrain_network paths.
Args:
opt (dict): Options.
resume_iter (int): Resume iteration.
"""
logger = get_root_logger()
if opt['path']['resume_state']:
# get all the networks
networks = [key for key in opt.keys() if key.startswith('network_')]
flag_pretrain = False
for network in networks:
if opt['path'].get(f'pretrain_{network}') is not None:
flag_pretrain = True
if flag_pretrain:
logger.warning('pretrain_network path will be ignored during resuming.')
# set pretrained model paths
for network in networks:
name = f'pretrain_{network}'
basename = network.replace('network_', '')
if opt['path'].get('ignore_resume_networks') is None or (basename
not in opt['path']['ignore_resume_networks']):
opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
logger.info(f"Set {name} to {opt['path'][name]}")
def sizeof_fmt(size, suffix='B'):
"""Get human readable file size.
Args:
size (int): File size.
suffix (str): Suffix. Default: 'B'.
Return:
str: Formated file siz.
"""
for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
if abs(size) < 1024.0:
return f'{size:3.1f} {unit}{suffix}'
size /= 1024.0
return f'{size:3.1f} Y{suffix}'
================================================
FILE: basicsr/utils/options.py
================================================
import yaml
import time
from collections import OrderedDict
from os import path as osp
from basicsr.utils.misc import get_time_str
def ordered_yaml():
"""Support OrderedDict for yaml.
Returns:
yaml Loader and Dumper.
"""
try:
from yaml import CDumper as Dumper
from yaml import CLoader as Loader
except ImportError:
from yaml import Dumper, Loader
_mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
def dict_representer(dumper, data):
return dumper.represent_dict(data.items())
def dict_constructor(loader, node):
return OrderedDict(loader.construct_pairs(node))
Dumper.add_representer(OrderedDict, dict_representer)
Loader.add_constructor(_mapping_tag, dict_constructor)
return Loader, Dumper
def parse(opt_path, root_path, is_train=True):
"""Parse option file.
Args:
opt_path (str): Option file path.
is_train (str): Indicate whether in training or not. Default: True.
Returns:
(dict): Options.
"""
with open(opt_path, mode='r') as f:
Loader, _ = ordered_yaml()
opt = yaml.load(f, Loader=Loader)
opt['is_train'] = is_train
# opt['name'] = f"{get_time_str()}_{opt['name']}"
if opt['path'].get('resume_state', None): # Shangchen added
resume_state_path = opt['path'].get('resume_state')
opt['name'] = resume_state_path.split("/")[-3]
else:
opt['name'] = f"{get_time_str()}_{opt['name']}"
# datasets
for phase, dataset in opt['datasets'].items():
# for several datasets, e.g., test_1, test_2
phase = phase.split('_')[0]
dataset['phase'] = phase
if 'scale' in opt:
dataset['scale'] = opt['scale']
if dataset.get('dataroot_gt') is not None:
dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
if dataset.get('dataroot_lq') is not None:
dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
# paths
for key, val in opt['path'].items():
if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
opt['path'][key] = osp.expanduser(val)
if is_train:
experiments_root = osp.join(root_path, 'experiments', opt['name'])
opt['path']['experiments_root'] = experiments_root
opt['path']['models'] = osp.join(experiments_root, 'models')
opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
opt['path']['log'] = experiments_root
opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
else: # test
results_root = osp.join(root_path, 'results', opt['name'])
opt['path']['results_root'] = results_root
opt['path']['log'] = results_root
opt['path']['visualization'] = osp.join(results_root, 'visualization')
return opt
def dict2str(opt, indent_level=1):
"""dict to string for printing options.
Args:
opt (dict): Option dict.
indent_level (int): Indent level. Default: 1.
Return:
(str): Option string for printing.
"""
msg = '\n'
for k, v in opt.items():
if isinstance(v, dict):
msg += ' ' * (indent_level * 2) + k + ':['
msg += dict2str(v, indent_level + 1)
msg += ' ' * (indent_level * 2) + ']\n'
else:
msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
return msg
================================================
FILE: basicsr/utils/realesrgan_utils.py
================================================
import cv2
import math
import numpy as np
import os
import queue
import threading
import torch
from torch.nn import functional as F
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils.misc import get_device
# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class RealESRGANer():
"""A helper class for upsampling images with RealESRGAN.
Args:
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
model (nn.Module): The defined network. Default: None.
tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
input images into tiles, and then process each of them. Finally, they will be merged into one image.
0 denotes for do not use tile. Default: 0.
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
half (float): Whether to use half precision during inference. Default: False.
"""
def __init__(self,
scale,
model_path,
model=None,
tile=0,
tile_pad=10,
pre_pad=10,
half=False,
device=None,
gpu_id=None):
self.scale = scale
self.tile_size = tile
self.tile_pad = tile_pad
self.pre_pad = pre_pad
self.mod_scale = None
self.half = half
# initialize model
# if gpu_id:
# self.device = torch.device(
# f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
# else:
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.device = get_device(gpu_id) if device is None else device
# if the model_path starts with https, it will first download models to the folder: realesrgan/weights
if model_path.startswith('https://'):
model_path = load_file_from_url(
url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None)
loadnet = torch.load(model_path, map_location=torch.device('cpu'))
# prefer to use params_ema
if 'params_ema' in loadnet:
keyname = 'params_ema'
else:
keyname = 'params'
model.load_state_dict(loadnet[keyname], strict=True)
model.eval()
self.model = model.to(self.device)
if self.half:
self.model = self.model.half()
def pre_process(self, img):
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible
"""
img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
self.img = img.unsqueeze(0).to(self.device)
if self.half:
self.img = self.img.half()
# pre_pad
if self.pre_pad != 0:
self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
# mod pad for divisible borders
if self.scale == 2:
self.mod_scale = 2
elif self.scale == 1:
self.mod_scale = 4
if self.mod_scale is not None:
self.mod_pad_h, self.mod_pad_w = 0, 0
_, _, h, w = self.img.size()
if (h % self.mod_scale != 0):
self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
if (w % self.mod_scale != 0):
self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
def process(self):
# model inference
self.output = self.model(self.img)
def tile_process(self):
"""It will first crop input images to tiles, and then process each tile.
Finally, all the processed tiles are merged into one images.
Modified from: https://github.com/ata4/esrgan-launcher
"""
batch, channel, height, width = self.img.shape
output_height = height * self.scale
output_width = width * self.scale
output_shape = (batch, channel, output_height, output_width)
# start with black image
self.output = self.img.new_zeros(output_shape)
tiles_x = math.ceil(width / self.tile_size)
tiles_y = math.ceil(height / self.tile_size)
# loop over all tiles
for y in range(tiles_y):
for x in range(tiles_x):
# extract tile from input image
ofs_x = x * self.tile_size
ofs_y = y * self.tile_size
# input tile area on total image
input_start_x = ofs_x
input_end_x = min(ofs_x + self.tile_size, width)
input_start_y = ofs_y
input_end_y = min(ofs_y + self.tile_size, height)
# input tile area on total image with padding
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
input_end_x_pad = min(input_end_x + self.tile_pad, width)
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
input_end_y_pad = min(input_end_y + self.tile_pad, height)
# input tile dimensions
input_tile_width = input_end_x - input_start_x
input_tile_height = input_end_y - input_start_y
tile_idx = y * tiles_x + x + 1
input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
# upscale tile
try:
with torch.no_grad():
output_tile = self.model(input_tile)
except RuntimeError as error:
print('Error', error)
# print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
# output tile area on total image
output_start_x = input_start_x * self.scale
output_end_x = input_end_x * self.scale
output_start_y = input_start_y * self.scale
output_end_y = input_end_y * self.scale
# output tile area without padding
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
# put tile into output image
self.output[:, :, output_start_y:output_end_y,
output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
output_start_x_tile:output_end_x_tile]
def post_process(self):
# remove extra pad
if self.mod_scale is not None:
_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
# remove prepad
if self.pre_pad != 0:
_, _, h, w = self.output.size()
self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
return self.output
@torch.no_grad()
def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
h_input, w_input = img.shape[0:2]
# img: numpy
img = img.astype(np.float32)
if np.max(img) > 256: # 16-bit image
max_range = 65535
print('\tInput is a 16-bit image')
else:
max_range = 255
img = img / max_range
if len(img.shape) == 2: # gray image
img_mode = 'L'
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
elif img.shape[2] == 4: # RGBA image with alpha channel
img_mode = 'RGBA'
alpha = img[:, :, 3]
img = img[:, :, 0:3]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if alpha_upsampler == 'realesrgan':
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
else:
img_mode = 'RGB'
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# ------------------- process image (without the alpha channel) ------------------- #
try:
with torch.no_grad():
self.pre_process(img)
if self.tile_size > 0:
self.tile_process()
else:
self.process()
output_img_t = self.post_process()
output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
if img_mode == 'L':
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
del output_img_t
torch.cuda.empty_cache()
except RuntimeError as error:
print(f"Failed inference for RealESRGAN: {error}")
# ------------------- process the alpha channel if necessary ------------------- #
if img_mode == 'RGBA':
if alpha_upsampler == 'realesrgan':
self.pre_process(alpha)
if self.tile_size > 0:
self.tile_process()
else:
self.process()
output_alpha = self.post_process()
output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
else: # use the cv2 resize for alpha channel
h, w = alpha.shape[0:2]
output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
# merge the alpha channel
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
output_img[:, :, 3] = output_alpha
# ------------------------------ return ------------------------------ #
if max_range == 65535: # 16-bit image
output = (output_img * 65535.0).round().astype(np.uint16)
else:
output = (output_img * 255.0).round().astype(np.uint8)
if outscale is not None and outscale != float(self.scale):
output = cv2.resize(
output, (
int(w_input * outscale),
int(h_input * outscale),
), interpolation=cv2.INTER_LANCZOS4)
return output, img_mode
class PrefetchReader(threading.Thread):
"""Prefetch images.
Args:
img_list (list[str]): A image list of image paths to be read.
num_prefetch_queue (int): Number of prefetch queue.
"""
def __init__(self, img_list, num_prefetch_queue):
super().__init__()
self.que = queue.Queue(num_prefetch_queue)
self.img_list = img_list
def run(self):
for img_path in self.img_list:
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
self.que.put(img)
self.que.put(None)
def __next__(self):
next_item = self.que.get()
if next_item is None:
raise StopIteration
return next_item
def __iter__(self):
return self
class IOConsumer(threading.Thread):
def __init__(self, opt, que, qid):
super().__init__()
self._queue = que
self.qid = qid
self.opt = opt
def run(self):
while True:
msg = self._queue.get()
if isinstance(msg, str) and msg == 'quit':
break
output = msg['output']
save_path = msg['save_path']
cv2.imwrite(save_path, output)
print(f'IO worker {self.qid} is done.')
================================================
FILE: basicsr/utils/registry.py
================================================
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
class Registry():
"""
The registry that provides name -> object mapping, to support third-party
users' custom modules.
To create a registry (e.g. a backbone registry):
.. code-block:: python
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
.. code-block:: python
@BACKBONE_REGISTRY.register()
class MyBackbone():
...
Or:
.. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone)
"""
def __init__(self, name):
"""
Args:
name (str): the name of this registry
"""
self._name = name
self._obj_map = {}
def _do_register(self, name, obj):
assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
f"in '{self._name}' registry!")
self._obj_map[name] = obj
def register(self, obj=None):
"""
Register the given object under the the name `obj.__name__`.
Can be used as either a decorator or not.
See docstring of this class for usage.
"""
if obj is None:
# used as a decorator
def deco(func_or_class):
name = func_or_class.__name__
self._do_register(name, func_or_class)
return func_or_class
return deco
# used as a function call
name = obj.__name__
self._do_register(name, obj)
def get(self, name):
ret = self._obj_map.get(name)
if ret is None:
raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
return ret
def __contains__(self, name):
return name in self._obj_map
def __iter__(self):
return iter(self._obj_map.items())
def keys(self):
return self._obj_map.keys()
DATASET_REGISTRY = Registry('dataset')
ARCH_REGISTRY = Registry('arch')
MODEL_REGISTRY = Registry('model')
LOSS_REGISTRY = Registry('loss')
METRIC_REGISTRY = Registry('metric')
================================================
FILE: basicsr/utils/video_util.py
================================================
'''
The code is modified from the Real-ESRGAN:
https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_video.py
'''
import cv2
import sys
import numpy as np
try:
import ffmpeg
except ImportError:
import pip
pip.main(['install', '--user', 'ffmpeg-python'])
import ffmpeg
def get_video_meta_info(video_path):
ret = {}
probe = ffmpeg.probe(video_path)
video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])
ret['width'] = video_streams[0]['width']
ret['height'] = video_streams[0]['height']
ret['fps'] = eval(video_streams[0]['avg_frame_rate'])
ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None
ret['nb_frames'] = int(video_streams[0]['nb_frames'])
return ret
class VideoReader:
def __init__(self, video_path):
self.paths = [] # for image&folder type
self.audio = None
try:
self.stream_reader = (
ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24',
loglevel='error').run_async(
pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg'))
except FileNotFoundError:
print('Please install ffmpeg (not ffmpeg-python) by running\n',
'\t$ conda install -c conda-forge ffmpeg')
sys.exit(0)
meta = get_video_meta_info(video_path)
self.width = meta['width']
self.height = meta['height']
self.input_fps = meta['fps']
self.audio = meta['audio']
self.nb_frames = meta['nb_frames']
self.idx = 0
def get_resolution(self):
return self.height, self.width
def get_fps(self):
if self.input_fps is not None:
return self.input_fps
return 24
def get_audio(self):
return self.audio
def __len__(self):
return self.nb_frames
def get_frame_from_stream(self):
img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel
if not img_bytes:
return None
img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])
return img
def get_frame_from_list(self):
if self.idx >= self.nb_frames:
return None
img = cv2.imread(self.paths[self.idx])
self.idx += 1
return img
def get_frame(self):
return self.get_frame_from_stream()
def close(self):
self.stream_reader.stdin.close()
self.stream_reader.wait()
class VideoWriter:
def __init__(self, video_save_path, height, width, fps, audio):
if height > 2160:
print('You are generating video that is larger than 4K, which will be very slow due to IO speed.',
'We highly recommend to decrease the outscale(aka, -s).')
if audio is not None:
self.stream_writer = (
ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}',
framerate=fps).output(
audio,
video_save_path,
pix_fmt='yuv420p',
vcodec='libx264',
loglevel='error',
acodec='copy').overwrite_output().run_async(
pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg'))
else:
self.stream_writer = (
ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}',
framerate=fps).output(
video_save_path, pix_fmt='yuv420p', vcodec='libx264',
loglevel='error').overwrite_output().run_async(
pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg'))
def write_frame(self, frame):
try:
frame = frame.astype(np.uint8).tobytes()
self.stream_writer.stdin.write(frame)
except BrokenPipeError:
print('Please re-install ffmpeg and libx264 by running\n',
'\t$ conda install -c conda-forge ffmpeg\n',
'\t$ conda install -c conda-forge x264')
sys.exit(0)
def close(self):
self.stream_writer.stdin.close()
self.stream_writer.wait()
================================================
FILE: basicsr/version.py
================================================
# GENERATED VERSION FILE
# TIME: Thu Jun 26 05:59:40 2025
__version__ = '1.3.2'
__gitsha__ = '536df45'
version_info = (1, 3, 2)
================================================
FILE: facelib/detection/__init__.py
================================================
import os
import torch
from torch import nn
from copy import deepcopy
from facelib.utils import load_file_from_url
from facelib.utils import download_pretrained_models
from facelib.detection.yolov5face.models.common import Conv
from .retinaface.retinaface import RetinaFace
from .yolov5face.face_detector import YoloDetector
def init_detection_model(model_name, half=False, device='cuda'):
if 'retinaface' in model_name:
model = init_retinaface_model(model_name, half, device)
elif 'YOLOv5' in model_name:
model = init_yolov5face_model(model_name, device)
else:
raise NotImplementedError(f'{model_name} is not implemented.')
return model
def init_retinaface_model(model_name, half=False, device='cuda'):
if model_name == 'retinaface_resnet50':
model = RetinaFace(network_name='resnet50', half=half)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth'
elif model_name == 'retinaface_mobile0.25':
model = RetinaFace(network_name='mobile0.25', half=half)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')
model_path = load_file_from_url(url=model_url, model_dir='ckpts/facelib', progress=True, file_name=None)
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
# remove unnecessary 'module.'
for k, v in deepcopy(load_net).items():
if k.startswith('module.'):
load_net[k[7:]] = v
load_net.pop(k)
model.load_state_dict(load_net, strict=True)
model.eval()
model = model.to(device)
return model
def init_yolov5face_model(model_name, device='cuda'):
if model_name == 'YOLOv5l':
model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth'
elif model_name == 'YOLOv5n':
model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')
model_path = load_file_from_url(url=model_url, model_dir='ckpts/facelib', progress=True, file_name=None)
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
model.detector.load_state_dict(load_net, strict=True)
model.detector.eval()
model.detector = model.detector.to(device).float()
for m in model.detector.modules():
if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
m.inplace = True # pytorch 1.7.0 compatibility
elif isinstance(m, Conv):
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
return model
# Download from Google Drive
# def init_yolov5face_model(model_name, device='cuda'):
# if model_name == 'YOLOv5l':
# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
# f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'}
# elif model_name == 'YOLOv5n':
# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
# f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'}
# else:
# raise NotImplementedError(f'{model_name} is not implemented.')
# model_path = os.path.join('weights/facelib', list(f_id.keys())[0])
# if not os.path.exists(model_path):
# download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib')
# load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
# model.detector.load_state_dict(load_net, strict=True)
# model.detector.eval()
# model.detector = model.detector.to(device).float()
# for m in model.detector.modules():
# if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
# m.inplace = True # pytorch 1.7.0 compatibility
# elif isinstance(m, Conv):
# m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
# return model
================================================
FILE: facelib/detection/align_trans.py
================================================
import cv2
import numpy as np
from .matlab_cp2tform import get_similarity_transform_for_cv2
# reference facial points, a list of coordinates (x,y)
REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278],
[33.54930115, 92.3655014], [62.72990036, 92.20410156]]
DEFAULT_CROP_SIZE = (96, 112)
class FaceWarpException(Exception):
def __str__(self):
return 'In File {}:{}'.format(__file__, super.__str__(self))
def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
"""
Function:
----------
get reference 5 key points according to crop settings:
0. Set default crop_size:
if default_square:
crop_size = (112, 112)
else:
crop_size = (96, 112)
1. Pad the crop_size by inner_padding_factor in each side;
2. Resize crop_size into (output_size - outer_padding*2),
pad into output_size with outer_padding;
3. Output reference_5point;
Parameters:
----------
@output_size: (w, h) or None
size of aligned face image
@inner_padding_factor: (w_factor, h_factor)
padding factor for inner (w, h)
@outer_padding: (w_pad, h_pad)
each row is a pair of coordinates (x, y)
@default_square: True or False
if True:
default crop_size = (112, 112)
else:
default crop_size = (96, 112);
!!! make sure, if output_size is not None:
(output_size - outer_padding)
= some_scale * (default crop_size * (1.0 +
inner_padding_factor))
Returns:
----------
@reference_5point: 5x2 np.array
each row is a pair of transformed coordinates (x, y)
"""
tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
# 0) make the inner region a square
if default_square:
size_diff = max(tmp_crop_size) - tmp_crop_size
tmp_5pts += size_diff / 2
tmp_crop_size += size_diff
if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):
return tmp_5pts
if (inner_padding_factor == 0 and outer_padding == (0, 0)):
if output_size is None:
return tmp_5pts
else:
raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
# check output size
if not (0 <= inner_padding_factor <= 1.0):
raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
output_size = tmp_crop_size * \
(1 + inner_padding_factor * 2).astype(np.int32)
output_size += np.array(outer_padding)
if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])')
# 1) pad the inner region according inner_padding_factor
if inner_padding_factor > 0:
size_diff = tmp_crop_size * inner_padding_factor * 2
tmp_5pts += size_diff / 2
tmp_crop_size += np.round(size_diff).astype(np.int32)
# 2) resize the padded inner region
size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
raise FaceWarpException('Must have (output_size - outer_padding)'
'= some_scale * (crop_size * (1.0 + inner_padding_factor)')
scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
tmp_5pts = tmp_5pts * scale_factor
# size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
# tmp_5pts = tmp_5pts + size_diff / 2
tmp_crop_size = size_bf_outer_pad
# 3) add outer_padding to make output_size
reference_5point = tmp_5pts + np.array(outer_padding)
tmp_crop_size = output_size
return reference_5point
def get_affine_transform_matrix(src_pts, dst_pts):
"""
Function:
----------
get affine transform matrix 'tfm' from src_pts to dst_pts
Parameters:
----------
@src_pts: Kx2 np.array
source points matrix, each row is a pair of coordinates (x, y)
@dst_pts: Kx2 np.array
destination points matrix, each row is a pair of coordinates (x, y)
Returns:
----------
@tfm: 2x3 np.array
transform matrix from src_pts to dst_pts
"""
tfm = np.float32([[1, 0, 0], [0, 1, 0]])
n_pts = src_pts.shape[0]
ones = np.ones((n_pts, 1), src_pts.dtype)
src_pts_ = np.hstack([src_pts, ones])
dst_pts_ = np.hstack([dst_pts, ones])
A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
if rank == 3:
tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
elif rank == 2:
tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
return tfm
def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='smilarity'):
"""
Function:
----------
apply affine transform 'trans' to uv
Parameters:
----------
@src_img: 3x3 np.array
input image
@facial_pts: could be
1)a list of K coordinates (x,y)
or
2) Kx2 or 2xK np.array
each row or col is a pair of coordinates (x, y)
@reference_pts: could be
1) a list of K coordinates (x,y)
or
2) Kx2 or 2xK np.array
each row or col is a pair of coordinates (x, y)
or
3) None
if None, use default reference facial points
@crop_size: (w, h)
output face image size
@align_type: transform type, could be one of
1) 'similarity': use similarity transform
2) 'cv2_affine': use the first 3 points to do affine transform,
by calling cv2.getAffineTransform()
3) 'affine': use all points to do affine transform
Returns:
----------
@face_img: output face image with size (w, h) = @crop_size
"""
if reference_pts is None:
if crop_size[0] == 96 and crop_size[1] == 112:
reference_pts = REFERENCE_FACIAL_POINTS
else:
default_square = False
inner_padding_factor = 0
outer_padding = (0, 0)
output_size = crop_size
reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding,
default_square)
ref_pts = np.float32(reference_pts)
ref_pts_shp = ref_pts.shape
if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2')
if ref_pts_shp[0] == 2:
ref_pts = ref_pts.T
src_pts = np.float32(facial_pts)
src_pts_shp = src_pts.shape
if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2')
if src_pts_shp[0] == 2:
src_pts = src_pts.T
if src_pts.shape != ref_pts.shape:
raise FaceWarpException('facial_pts and reference_pts must have the same shape')
if align_type == 'cv2_affine':
tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
elif align_type == 'affine':
tfm = get_affine_transform_matrix(src_pts, ref_pts)
else:
tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
return face_img
================================================
FILE: facelib/detection/matlab_cp2tform.py
================================================
import numpy as np
from numpy.linalg import inv, lstsq
from numpy.linalg import matrix_rank as rank
from numpy.linalg import norm
class MatlabCp2tormException(Exception):
def __str__(self):
return 'In File {}:{}'.format(__file__, super.__str__(self))
def tformfwd(trans, uv):
"""
Function:
----------
apply affine transform 'trans' to uv
Parameters:
----------
@trans: 3x3 np.array
transform matrix
@uv: Kx2 np.array
each row is a pair of coordinates (x, y)
Returns:
----------
@xy: Kx2 np.array
each row is a pair of transformed coordinates (x, y)
"""
uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
xy = np.dot(uv, trans)
xy = xy[:, 0:-1]
return xy
def tforminv(trans, uv):
"""
Function:
----------
apply the inverse of affine transform 'trans' to uv
Parameters:
----------
@trans: 3x3 np.array
transform matrix
@uv: Kx2 np.array
each row is a pair of coordinates (x, y)
Returns:
----------
@xy: Kx2 np.array
each row is a pair of inverse-transformed coordinates (x, y)
"""
Tinv = inv(trans)
xy = tformfwd(Tinv, uv)
return xy
def findNonreflectiveSimilarity(uv, xy, options=None):
options = {'K': 2}
K = options['K']
M = xy.shape[0]
x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
X = np.vstack((tmp1, tmp2))
u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
U = np.vstack((u, v))
# We know that X * r = U
if rank(X) >= 2 * K:
r, _, _, _ = lstsq(X, U, rcond=-1)
r = np.squeeze(r)
else:
raise Exception('cp2tform:twoUniquePointsReq')
sc = r[0]
ss = r[1]
tx = r[2]
ty = r[3]
Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
T = inv(Tinv)
T[:, 2] = np.array([0, 0, 1])
return T, Tinv
def findSimilarity(uv, xy, options=None):
options = {'K': 2}
# uv = np.array(uv)
# xy = np.array(xy)
# Solve for trans1
trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
# Solve for trans2
# manually reflect the xy data across the Y-axis
xyR = xy
xyR[:, 0] = -1 * xyR[:, 0]
trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
# manually reflect the tform to undo the reflection done on xyR
TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
trans2 = np.dot(trans2r, TreflectY)
# Figure out if trans1 or trans2 is better
xy1 = tformfwd(trans1, uv)
norm1 = norm(xy1 - xy)
xy2 = tformfwd(trans2, uv)
norm2 = norm(xy2 - xy)
if norm1 <= norm2:
return trans1, trans1_inv
else:
trans2_inv = inv(trans2)
return trans2, trans2_inv
def get_similarity_transform(src_pts, dst_pts, reflective=True):
"""
Function:
----------
Find Similarity Transform Matrix 'trans':
u = src_pts[:, 0]
v = src_pts[:, 1]
x = dst_pts[:, 0]
y = dst_pts[:, 1]
[x, y, 1] = [u, v, 1] * trans
Parameters:
----------
@src_pts: Kx2 np.array
source points, each row is a pair of coordinates (x, y)
@dst_pts: Kx2 np.array
destination points, each row is a pair of transformed
coordinates (x, y)
@reflective: True or False
if True:
use reflective similarity transform
else:
use non-reflective similarity transform
Returns:
----------
@trans: 3x3 np.array
transform matrix from uv to xy
trans_inv: 3x3 np.array
inverse of trans, transform matrix from xy to uv
"""
if reflective:
trans, trans_inv = findSimilarity(src_pts, dst_pts)
else:
trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
return trans, trans_inv
def cvt_tform_mat_for_cv2(trans):
"""
Function:
----------
Convert Transform Matrix 'trans' into 'cv2_trans' which could be
directly used by cv2.warpAffine():
u = src_pts[:, 0]
v = src_pts[:, 1]
x = dst_pts[:, 0]
y = dst_pts[:, 1]
[x, y].T = cv_trans * [u, v, 1].T
Parameters:
----------
@trans: 3x3 np.array
transform matrix from uv to xy
Returns:
----------
@cv2_trans: 2x3 np.array
transform matrix from src_pts to dst_pts, could be directly used
for cv2.warpAffine()
"""
cv2_trans = trans[:, 0:2].T
return cv2_trans
def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
"""
Function:
----------
Find Similarity Transform Matrix 'cv2_trans' which could be
directly used by cv2.warpAffine():
u = src_pts[:, 0]
v = src_pts[:, 1]
x = dst_pts[:, 0]
y = dst_pts[:, 1]
[x, y].T = cv_trans * [u, v, 1].T
Parameters:
----------
@src_pts: Kx2 np.array
source points, each row is a pair of coordinates (x, y)
@dst_pts: Kx2 np.array
destination points, each row is a pair of transformed
coordinates (x, y)
reflective: True or False
if True:
use reflective similarity transform
else:
use non-reflective similarity transform
Returns:
----------
@cv2_trans: 2x3 np.array
transform matrix from src_pts to dst_pts, could be directly used
for cv2.warpAffine()
"""
trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
cv2_trans = cvt_tform_mat_for_cv2(trans)
return cv2_trans
if __name__ == '__main__':
"""
u = [0, 6, -2]
v = [0, 3, 5]
x = [-1, 0, 4]
y = [-1, -10, 4]
# In Matlab, run:
#
# uv = [u'; v'];
# xy = [x'; y'];
# tform_sim=cp2tform(uv,xy,'similarity');
#
# trans = tform_sim.tdata.T
# ans =
# -0.0764 -1.6190 0
# 1.6190 -0.0764 0
# -3.2156 0.0290 1.0000
# trans_inv = tform_sim.tdata.Tinv
# ans =
#
# -0.0291 0.6163 0
# -0.6163 -0.0291 0
# -0.0756 1.9826 1.0000
# xy_m=tformfwd(tform_sim, u,v)
#
# xy_m =
#
# -3.2156 0.0290
# 1.1833 -9.9143
# 5.0323 2.8853
# uv_m=tforminv(tform_sim, x,y)
#
# uv_m =
#
# 0.5698 1.3953
# 6.0872 2.2733
# -2.6570 4.3314
"""
u = [0, 6, -2]
v = [0, 3, 5]
x = [-1, 0, 4]
y = [-1, -10, 4]
uv = np.array((u, v)).T
xy = np.array((x, y)).T
print('\n--->uv:')
print(uv)
print('\n--->xy:')
print(xy)
trans, trans_inv = get_similarity_transform(uv, xy)
print('\n--->trans matrix:')
print(trans)
print('\n--->trans_inv matrix:')
print(trans_inv)
print('\n---> apply transform to uv')
print('\nxy_m = uv_augmented * trans')
uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
xy_m = np.dot(uv_aug, trans)
print(xy_m)
print('\nxy_m = tformfwd(trans, uv)')
xy_m = tformfwd(trans, uv)
print(xy_m)
print('\n---> apply inverse transform to xy')
print('\nuv_m = xy_augmented * trans_inv')
xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
uv_m = np.dot(xy_aug, trans_inv)
print(uv_m)
print('\nuv_m = tformfwd(trans_inv, xy)')
uv_m = tformfwd(trans_inv, xy)
print(uv_m)
uv_m = tforminv(trans, xy)
print('\nuv_m = tforminv(trans, xy)')
print(uv_m)
================================================
FILE: facelib/detection/retinaface/retinaface.py
================================================
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter
from facelib.detection.align_trans import get_reference_facial_points, warp_and_crop_face
from facelib.detection.retinaface.retinaface_net import FPN, SSH, MobileNetV1, make_bbox_head, make_class_head, make_landmark_head
from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
py_cpu_nms)
from basicsr.utils.misc import get_device
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = get_device()
def generate_config(network_name):
cfg_mnet = {
'name': 'mobilenet0.25',
'min_sizes': [[16, 32], [64, 128], [256, 512]],
'steps': [8, 16, 32],
'variance': [0.1, 0.2],
'clip': False,
'loc_weight': 2.0,
'gpu_train': True,
'batch_size': 32,
'ngpu': 1,
'epoch': 250,
'decay1': 190,
'decay2': 220,
'image_size': 640,
'return_layers': {
'stage1': 1,
'stage2': 2,
'stage3': 3
},
'in_channel': 32,
'out_channel': 64
}
cfg_re50 = {
'name': 'Resnet50',
'min_sizes': [[16, 32], [64, 128], [256, 512]],
'steps': [8, 16, 32],
'variance': [0.1, 0.2],
'clip': False,
'loc_weight': 2.0,
'gpu_train': True,
'batch_size': 24,
'ngpu': 4,
'epoch': 100,
'decay1': 70,
'decay2': 90,
'image_size': 840,
'return_layers': {
'layer2': 1,
'layer3': 2,
'layer4': 3
},
'in_channel': 256,
'out_channel': 256
}
if network_name == 'mobile0.25':
return cfg_mnet
elif network_name == 'resnet50':
return cfg_re50
else:
raise NotImplementedError(f'network_name={network_name}')
class RetinaFace(nn.Module):
def __init__(self, network_name='resnet50', half=False, phase='test'):
super(RetinaFace, self).__init__()
self.half_inference = half
cfg = generate_config(network_name)
self.backbone = cfg['name']
self.model_name = f'retinaface_{network_name}'
self.cfg = cfg
self.phase = phase
self.target_size, self.max_size = 1600, 2150
self.resize, self.scale, self.scale1 = 1., None, None
self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]]).to(device)
self.reference = get_reference_facial_points(default_square=True)
# Build network.
backbone = None
if cfg['name'] == 'mobilenet0.25':
backbone = MobileNetV1()
self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
elif cfg['name'] == 'Resnet50':
import torchvision.models as models
backbone = models.resnet50(pretrained=False)
self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
in_channels_stage2 = cfg['in_channel']
in_channels_list = [
in_channels_stage2 * 2,
in_channels_stage2 * 4,
in_channels_stage2 * 8,
]
out_channels = cfg['out_channel']
self.fpn = FPN(in_channels_list, out_channels)
self.ssh1 = SSH(out_channels, out_channels)
self.ssh2 = SSH(out_channels, out_channels)
self.ssh3 = SSH(out_channels, out_channels)
self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
self.to(device)
self.eval()
if self.half_inference:
self.half()
def forward(self, inputs):
out = self.body(inputs)
if self.backbone == 'mobilenet0.25' or self.backbone == 'Resnet50':
out = list(out.values())
# FPN
fpn = self.fpn(out)
# SSH
feature1 = self.ssh1(fpn[0])
feature2 = self.ssh2(fpn[1])
feature3 = self.ssh3(fpn[2])
features = [feature1, feature2, feature3]
bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)]
ldm_regressions = (torch.cat(tmp, dim=1))
if self.phase == 'train':
output = (bbox_regressions, classifications, ldm_regressions)
else:
output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
return output
def __detect_faces(self, inputs):
# get scale
height, width = inputs.shape[2:]
self.scale = torch.tensor([width, height, width, height], dtype=torch.float32).to(device)
tmp = [width, height, width, height, width, height, width, height, width, height]
self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device)
# forawrd
inputs = inputs.to(device)
if self.half_inference:
inputs = inputs.half()
loc, conf, landmarks = self(inputs)
# get priorbox
priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
priors = priorbox.forward().to(device)
return loc, conf, landmarks, priors
# single image detection
def transform(self, image, use_origin_size):
# convert to opencv format
if isinstance(image, Image.Image):
image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
image = image.astype(np.float32)
# testing scale
im_size_min = np.min(image.shape[0:2])
im_size_max = np.max(image.shape[0:2])
resize = float(self.target_size) / float(im_size_min)
# prevent bigger axis from being more than max_size
if np.round(resize * im_size_max) > self.max_size:
resize = float(self.max_size) / float(im_size_max)
resize = 1 if use_origin_size else resize
# resize
if resize != 1:
image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
# convert to torch.tensor format
# image -= (104, 117, 123)
image = image.transpose(2, 0, 1)
image = torch.from_numpy(image).unsqueeze(0)
return image, resize
def detect_faces(
self,
image,
conf_threshold=0.8,
nms_threshold=0.4,
use_origin_size=True,
):
"""
Params:
imgs: BGR image
"""
image, self.resize = self.transform(image, use_origin_size)
image = image.to(device)
if self.half_inference:
image = image.half()
image = image - self.mean_tensor
loc, conf, landmarks, priors = self.__detect_faces(image)
boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance'])
boxes = boxes * self.scale / self.resize
boxes = boxes.cpu().numpy()
scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg['variance'])
landmarks = landmarks * self.scale1 / self.resize
landmarks = landmarks.cpu().numpy()
# ignore low scores
inds = np.where(scores > conf_threshold)[0]
boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]
# sort
order = scores.argsort()[::-1]
boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]
# do NMS
bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
keep = py_cpu_nms(bounding_boxes, nms_threshold)
bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
# self.t['forward_pass'].toc()
# print(self.t['forward_pass'].average_time)
# import sys
# sys.stdout.flush()
return np.concatenate((bounding_boxes, landmarks), axis=1)
def __align_multi(self, image, boxes, landmarks, limit=None):
if len(boxes) < 1:
return [], []
if limit:
boxes = boxes[:limit]
landmarks = landmarks[:limit]
faces = []
for landmark in landmarks:
facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)]
warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112))
faces.append(warped_face)
return np.concatenate((boxes, landmarks), axis=1), faces
def align_multi(self, img, conf_threshold=0.8, limit=None):
rlt = self.detect_faces(img, conf_threshold=conf_threshold)
boxes, landmarks = rlt[:, 0:5], rlt[:, 5:]
return self.__align_multi(img, boxes, landmarks, limit)
# batched detection
def batched_transform(self, frames, use_origin_size):
"""
Arguments:
frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c],
type=np.float32, BGR format).
use_origin_size: whether to use origin size.
"""
from_PIL = True if isinstance(frames[0], Image.Image) else False
# convert to opencv format
if from_PIL:
frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames]
frames = np.asarray(frames, dtype=np.float32)
# testing scale
im_size_min = np.min(frames[0].shape[0:2])
im_size_max = np.max(frames[0].shape[0:2])
resize = float(self.target_size) / float(im_size_min)
# prevent bigger axis from being more than max_size
if np.round(resize * im_size_max) > self.max_size:
resize = float(self.max_size) / float(im_size_max)
resize = 1 if use_origin_size else resize
# resize
if resize != 1:
if not from_PIL:
frames = F.interpolate(frames, scale_factor=resize)
else:
frames = [
cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
for frame in frames
]
# convert to torch.tensor format
if not from_PIL:
frames = frames.transpose(1, 2).transpose(1, 3).contiguous()
else:
frames = frames.transpose((0, 3, 1, 2))
frames = torch.from_numpy(frames)
return frames, resize
def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True):
"""
Arguments:
frames: a list of PIL.Image, or np.array(shape=[n, h, w, c],
type=np.uint8, BGR format).
conf_threshold: confidence threshold.
nms_threshold: nms threshold.
use_origin_size: whether to use origin size.
Returns:
final_bounding_boxes: list of np.array ([n_boxes, 5],
type=np.float32).
final_landmarks: list of np.array ([n_boxes, 10], type=np.float32).
"""
# self.t['forward_pass'].tic()
frames, self.resize = self.batched_transform(frames, use_origin_size)
frames = frames.to(device)
frames = frames - self.mean_tensor
b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)
final_bounding_boxes, final_landmarks = [], []
# decode
priors = priors.unsqueeze(0)
b_loc = batched_decode(b_loc, priors, self.cfg['variance']) * self.scale / self.resize
b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg['variance']) * self.scale1 / self.resize
b_conf = b_conf[:, :, 1]
# index for selection
b_indice = b_conf > conf_threshold
# concat
b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float()
for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice):
# ignore low scores
pred, landm = pred[inds, :], landm[inds, :]
if pred.shape[0] == 0:
final_bounding_boxes.append(np.array([], dtype=np.float32))
final_landmarks.append(np.array([], dtype=np.float32))
continue
# sort
# order = score.argsort(descending=True)
# box, landm, score = box[order], landm[order], score[order]
# to CPU
bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy()
# NMS
keep = py_cpu_nms(bounding_boxes, nms_threshold)
bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep]
# append
final_bounding_boxes.append(bounding_boxes)
final_landmarks.append(landmarks)
# self.t['forward_pass'].toc(average=True)
# self.batch_time += self.t['forward_pass'].diff
# self.total_frame += len(frames)
# print(self.batch_time / self.total_frame)
return final_bounding_boxes, final_landmarks
================================================
FILE: facelib/detection/retinaface/retinaface_net.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
def conv_bn(inp, oup, stride=1, leaky=0):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
nn.LeakyReLU(negative_slope=leaky, inplace=True))
def conv_bn_no_relu(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
nn.BatchNorm2d(oup),
)
def conv_bn1X1(inp, oup, stride, leaky=0):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup),
nn.LeakyReLU(negative_slope=leaky, inplace=True))
def conv_dw(inp, oup, stride, leaky=0.1):
return nn.Sequential(
nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
nn.BatchNorm2d(inp),
nn.LeakyReLU(negative_slope=leaky, inplace=True),
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
nn.LeakyReLU(negative_slope=leaky, inplace=True),
)
class SSH(nn.Module):
def __init__(self, in_channel, out_channel):
super(SSH, self).__init__()
assert out_channel % 4 == 0
leaky = 0
if (out_channel <= 64):
leaky = 0.1
self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
def forward(self, input):
conv3X3 = self.conv3X3(input)
conv5X5_1 = self.conv5X5_1(input)
conv5X5 = self.conv5X5_2(conv5X5_1)
conv7X7_2 = self.conv7X7_2(conv5X5_1)
conv7X7 = self.conv7x7_3(conv7X7_2)
out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
out = F.relu(out)
return out
class FPN(nn.Module):
def __init__(self, in_channels_list, out_channels):
super(FPN, self).__init__()
leaky = 0
if (out_channels <= 64):
leaky = 0.1
self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)
self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
def forward(self, input):
# names = list(input.keys())
# input = list(input.values())
output1 = self.output1(input[0])
output2 = self.output2(input[1])
output3 = self.output3(input[2])
up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest')
output2 = output2 + up3
output2 = self.merge2(output2)
up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest')
output1 = output1 + up2
output1 = self.merge1(output1)
out = [output1, output2, output3]
return out
class MobileNetV1(nn.Module):
def __init__(self):
super(MobileNetV1, self).__init__()
self.stage1 = nn.Sequential(
conv_bn(3, 8, 2, leaky=0.1), # 3
conv_dw(8, 16, 1), # 7
conv_dw(16, 32, 2), # 11
conv_dw(32, 32, 1), # 19
conv_dw(32, 64, 2), # 27
conv_dw(64, 64, 1), # 43
)
self.stage2 = nn.Sequential(
conv_dw(64, 128, 2), # 43 + 16 = 59
conv_dw(128, 128, 1), # 59 + 32 = 91
conv_dw(128, 128, 1), # 91 + 32 = 123
conv_dw(128, 128, 1), # 123 + 32 = 155
conv_dw(128, 128, 1), # 155 + 32 = 187
conv_dw(128, 128, 1), # 187 + 32 = 219
)
self.stage3 = nn.Sequential(
conv_dw(128, 256, 2), # 219 +3 2 = 241
conv_dw(256, 256, 1), # 241 + 64 = 301
)
self.avg = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(256, 1000)
def forward(self, x):
x = self.stage1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.avg(x)
# x = self.model(x)
x = x.view(-1, 256)
x = self.fc(x)
return x
class ClassHead(nn.Module):
def __init__(self, inchannels=512, num_anchors=3):
super(ClassHead, self).__init__()
self.num_anchors = num_anchors
self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)
def forward(self, x):
out = self.conv1x1(x)
out = out.permute(0, 2, 3, 1).contiguous()
return out.view(out.shape[0], -1, 2)
class BboxHead(nn.Module):
def __init__(self, inchannels=512, num_anchors=3):
super(BboxHead, self).__init__()
self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)
def forward(self, x):
out = self.conv1x1(x)
out = out.permute(0, 2, 3, 1).contiguous()
return out.view(out.shape[0], -1, 4)
class LandmarkHead(nn.Module):
def __init__(self, inchannels=512, num_anchors=3):
super(LandmarkHead, self).__init__()
self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)
def forward(self, x):
out = self.conv1x1(x)
out = out.permute(0, 2, 3, 1).contiguous()
return out.view(out.shape[0], -1, 10)
def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
classhead = nn.ModuleList()
for i in range(fpn_num):
classhead.append(ClassHead(inchannels, anchor_num))
return classhead
def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
bboxhead = nn.ModuleList()
for i in range(fpn_num):
bboxhead.append(BboxHead(inchannels, anchor_num))
return bboxhead
def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
landmarkhead = nn.ModuleList()
for i in range(fpn_num):
landmarkhead.append(LandmarkHead(inchannels, anchor_num))
return landmarkhead
================================================
FILE: facelib/detection/retinaface/retinaface_utils.py
================================================
import numpy as np
import torch
import torchvision
from itertools import product as product
from math import ceil
class PriorBox(object):
def __init__(self, cfg, image_size=None, phase='train'):
super(PriorBox, self).__init__()
self.min_sizes = cfg['min_sizes']
self.steps = cfg['steps']
self.clip = cfg['clip']
self.image_size = image_size
self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps]
self.name = 's'
def forward(self):
anchors = []
for k, f in enumerate(self.feature_maps):
min_sizes = self.min_sizes[k]
for i, j in product(range(f[0]), range(f[1])):
for min_size in min_sizes:
s_kx = min_size / self.image_size[1]
s_ky = min_size / self.image_size[0]
dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
for cy, cx in product(dense_cy, dense_cx):
anchors += [cx, cy, s_kx, s_ky]
# back to torch land
output = torch.Tensor(anchors).view(-1, 4)
if self.clip:
output.clamp_(max=1, min=0)
return output
def py_cpu_nms(dets, thresh):
"""Pure Python NMS baseline."""
keep = torchvision.ops.nms(
boxes=torch.Tensor(dets[:, :4]),
scores=torch.Tensor(dets[:, 4]),
iou_threshold=thresh,
)
return list(keep)
def point_form(boxes):
""" Convert prior_boxes to (xmin, ymin, xmax, ymax)
representation for comparison to point form ground truth data.
Args:
boxes: (tensor) center-size default boxes from priorbox layers.
Return:
boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
"""
return torch.cat(
(
boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin
boxes[:, :2] + boxes[:, 2:] / 2),
1) # xmax, ymax
def center_size(boxes):
""" Convert prior_boxes to (cx, cy, w, h)
representation for comparison to center-size form ground truth data.
Args:
boxes: (tensor) point_form boxes
Return:
boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
"""
return torch.cat(
(boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy
boxes[:, 2:] - boxes[:, :2],
1) # w, h
def intersect(box_a, box_b):
""" We resize both tensors to [A,B,2] without new malloc:
[A,2] -> [A,1,2] -> [A,B,2]
[B,2] -> [1,B,2] -> [A,B,2]
Then we compute the area of intersect between box_a and box_b.
Args:
box_a: (tensor) bounding boxes, Shape: [A,4].
box_b: (tensor) bounding boxes, Shape: [B,4].
Return:
(tensor) intersection area, Shape: [A,B].
"""
A = box_a.size(0)
B = box_b.size(0)
max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
inter = torch.clamp((max_xy - min_xy), min=0)
return inter[:, :, 0] * inter[:, :, 1]
def jaccard(box_a, box_b):
"""Compute the jaccard overlap of two sets of boxes. The jaccard overlap
is simply the intersection over union of two boxes. Here we operate on
ground truth boxes and default boxes.
E.g.:
A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
Args:
box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
Return:
jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
"""
inter = intersect(box_a, box_b)
area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
union = area_a + area_b - inter
return inter / union # [A,B]
def matrix_iou(a, b):
"""
return iou of a and b, numpy version for data augenmentation
"""
lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
return area_i / (area_a[:, np.newaxis] + area_b - area_i)
def matrix_iof(a, b):
"""
return iof of a and b, numpy version for data augenmentation
"""
lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
return area_i / np.maximum(area_a[:, np.newaxis], 1)
def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
"""Match each prior box with the ground truth box of the highest jaccard
overlap, encode the bounding boxes, then return the matched indices
corresponding to both confidence and location preds.
Args:
threshold: (float) The overlap threshold used when matching boxes.
truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
variances: (tensor) Variances corresponding to each prior coord,
Shape: [num_priors, 4].
labels: (tensor) All the class labels for the image, Shape: [num_obj].
landms: (tensor) Ground truth landms, Shape [num_obj, 10].
loc_t: (tensor) Tensor to be filled w/ encoded location targets.
conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
idx: (int) current batch index
Return:
The matched indices corresponding to 1)location 2)confidence
3)landm preds.
"""
# jaccard index
overlaps = jaccard(truths, point_form(priors))
# (Bipartite Matching)
# [1,num_objects] best prior for each ground truth
best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
# ignore hard gt
valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
if best_prior_idx_filter.shape[0] <= 0:
loc_t[idx] = 0
conf_t[idx] = 0
return
# [1,num_priors] best ground truth for each prior
best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
best_truth_idx.squeeze_(0)
best_truth_overlap.squeeze_(0)
best_prior_idx.squeeze_(1)
best_prior_idx_filter.squeeze_(1)
best_prior_overlap.squeeze_(1)
best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
# TODO refactor: index best_prior_idx with long tensor
# ensure every gt matches with its prior of max overlap
for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes
best_truth_idx[best_prior_idx[j]] = j
matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来
conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来
conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本
loc = encode(matches, priors, variances)
matches_landm = landms[best_truth_idx]
landm = encode_landm(matches_landm, priors, variances)
loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
conf_t[idx] = conf # [num_priors] top class label for each prior
landm_t[idx] = landm
def encode(matched, priors, variances):
"""Encode the variances from the priorbox layers into the ground truth boxes
we have matched (based on jaccard overlap) with the prior boxes.
Args:
matched: (tensor) Coords of ground truth for each prior in point-form
Shape: [num_priors, 4].
priors: (tensor) Prior boxes in center-offset form
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
encoded boxes (tensor), Shape: [num_priors, 4]
"""
# dist b/t match center and prior's center
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
# encode variance
g_cxcy /= (variances[0] * priors[:, 2:])
# match wh / prior wh
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
g_wh = torch.log(g_wh) / variances[1]
# return target for smooth_l1_loss
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
def encode_landm(matched, priors, variances):
"""Encode the variances from the priorbox layers into the ground truth boxes
we have matched (based on jaccard overlap) with the prior boxes.
Args:
matched: (tensor) Coords of ground truth for each prior in point-form
Shape: [num_priors, 10].
priors: (tensor) Prior boxes in center-offset form
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
encoded landm (tensor), Shape: [num_priors, 10]
"""
# dist b/t match center and prior's center
matched = torch.reshape(matched, (matched.size(0), 5, 2))
priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
g_cxcy = matched[:, :, :2] - priors[:, :, :2]
# encode variance
g_cxcy /= (variances[0] * priors[:, :, 2:])
# g_cxcy /= priors[:, :, 2:]
g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
# return target for smooth_l1_loss
return g_cxcy
# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
loc (tensor): location predictions for loc layers,
Shape: [num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""
boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
boxes[:, :2] -= boxes[:, 2:] / 2
boxes[:, 2:] += boxes[:, :2]
return boxes
def decode_landm(pre, priors, variances):
"""Decode landm from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
pre (tensor): landm predictions for loc layers,
Shape: [num_priors,10]
priors (tensor): Prior boxes in center-offset form.
Shape: [num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded landm predictions
"""
tmp = (
priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
)
landms = torch.cat(tmp, dim=1)
return landms
def batched_decode(b_loc, priors, variances):
"""Decode locations from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
b_loc (tensor): location predictions for loc layers,
Shape: [num_batches,num_priors,4]
priors (tensor): Prior boxes in center-offset form.
Shape: [1,num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded bounding box predictions
"""
boxes = (
priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:],
priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]),
)
boxes = torch.cat(boxes, dim=2)
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
boxes[:, :, 2:] += boxes[:, :, :2]
return boxes
def batched_decode_landm(pre, priors, variances):
"""Decode landm from predictions using priors to undo
the encoding we did for offset regression at train time.
Args:
pre (tensor): landm predictions for loc layers,
Shape: [num_batches,num_priors,10]
priors (tensor): Prior boxes in center-offset form.
Shape: [1,num_priors,4].
variances: (list[float]) Variances of priorboxes
Return:
decoded landm predictions
"""
landms = (
priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:],
priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:],
priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:],
priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:],
priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:],
)
landms = torch.cat(landms, dim=2)
return landms
def log_sum_exp(x):
"""Utility function for computing log_sum_exp while determining
This will be used to determine unaveraged confidence loss across
all examples in a batch.
Args:
x (Variable(tensor)): conf_preds from conf layers
"""
x_max = x.data.max()
return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max
# Original author: Francisco Massa:
# https://github.com/fmassa/object-detection.torch
# Ported to PyTorch by Max deGroot (02/01/2017)
def nms(boxes, scores, overlap=0.5, top_k=200):
"""Apply non-maximum suppression at test time to avoid detecting too many
overlapping bounding boxes for a given object.
Args:
boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
scores: (tensor) The class predscores for the img, Shape:[num_priors].
overlap: (float) The overlap thresh for suppressing unnecessary boxes.
top_k: (int) The Maximum number of box preds to consider.
Return:
The indices of the kept boxes with respect to num_priors.
"""
keep = torch.Tensor(scores.size(0)).fill_(0).long()
if boxes.numel() == 0:
return keep
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
area = torch.mul(x2 - x1, y2 - y1)
v, idx = scores.sort(0) # sort in ascending order
# I = I[v >= 0.01]
idx = idx[-top_k:] # indices of the top-k largest vals
xx1 = boxes.new()
yy1 = boxes.new()
xx2 = boxes.new()
yy2 = boxes.new()
w = boxes.new()
h = boxes.new()
# keep = torch.Tensor()
count = 0
while idx.numel() > 0:
i = idx[-1] # index of current largest val
# keep.append(i)
keep[count] = i
count += 1
if idx.size(0) == 1:
break
idx = idx[:-1] # remove kept element from view
# load bboxes of next highest vals
torch.index_select(x1, 0, idx, out=xx1)
torch.index_select(y1, 0, idx, out=yy1)
torch.index_select(x2, 0, idx, out=xx2)
torch.index_select(y2, 0, idx, out=yy2)
# store element-wise max with next highest score
xx1 = torch.clamp(xx1, min=x1[i])
yy1 = torch.clamp(yy1, min=y1[i])
xx2 = torch.clamp(xx2, max=x2[i])
yy2 = torch.clamp(yy2, max=y2[i])
w.resize_as_(xx2)
h.resize_as_(yy2)
w = xx2 - xx1
h = yy2 - yy1
# check sizes of xx1 and xx2.. after each iteration
w = torch.clamp(w, min=0.0)
h = torch.clamp(h, min=0.0)
inter = w * h
# IoU = i / (area(a) + area(b) - i)
rem_areas = torch.index_select(area, 0, idx) # load remaining areas)
union = (rem_areas - inter) + area[i]
IoU = inter / union # store result in iou
# keep only elements with an IoU <= overlap
idx = idx[IoU.le(overlap)]
return keep, count
================================================
FILE: facelib/detection/yolov5face/__init__.py
================================================
================================================
FILE: facelib/detection/yolov5face/face_detector.py
================================================
import cv2
import copy
import re
import torch
import numpy as np
from pathlib import Path
from facelib.detection.yolov5face.models.yolo import Model
from facelib.detection.yolov5face.utils.datasets import letterbox
from facelib.detection.yolov5face.utils.general import (
check_img_size,
non_max_suppression_face,
scale_coords,
scale_coords_landmarks,
)
# IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9)
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
torch.__version__)[0][:3])] >= [1, 9, 0]
def isListempty(inList):
if isinstance(inList, list): # Is a list
return all(map(isListempty, inList))
return False # Not a list
class YoloDetector:
def __init__(
self,
config_name,
min_face=10,
target_size=None,
device='cuda',
):
"""
config_name: name of .yaml config with network configuration from models/ folder.
min_face : minimal face size in pixels.
target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080.
None for original resolution.
"""
self._class_path = Path(__file__).parent.absolute()
self.target_size = target_size
self.min_face = min_face
self.detector = Model(cfg=config_name)
self.device = device
def _preprocess(self, imgs):
"""
Preprocessing image before passing through the network. Resize and conversion to torch tensor.
"""
pp_imgs = []
for img in imgs:
h0, w0 = img.shape[:2] # orig hw
if self.target_size:
r = self.target_size / min(h0, w0) # resize image to img_size
if r < 1:
img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR)
imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size
img = letterbox(img, new_shape=imgsz)[0]
pp_imgs.append(img)
pp_imgs = np.array(pp_imgs)
pp_imgs = pp_imgs.transpose(0, 3, 1, 2)
pp_imgs = torch.from_numpy(pp_imgs).to(self.device)
pp_imgs = pp_imgs.float() # uint8 to fp16/32
return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0
def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres):
"""
Postprocessing of raw pytorch model output.
Returns:
bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
"""
bboxes = [[] for _ in range(len(origimgs))]
landmarks = [[] for _ in range(len(origimgs))]
pred = non_max_suppression_face(pred, conf_thres, iou_thres)
for image_id, origimg in enumerate(origimgs):
img_shape = origimg.shape
image_height, image_width = img_shape[:2]
gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh
gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks
det = pred[image_id].cpu()
scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round()
scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round()
for j in range(det.size()[0]):
box = (det[j, :4].view(1, 4) / gn).view(-1).tolist()
box = list(
map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height])
)
if box[3] - box[1] < self.min_face:
continue
lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist()
lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)]))
lm = [lm[i : i + 2] for i in range(0, len(lm), 2)]
bboxes[image_id].append(box)
landmarks[image_id].append(lm)
return bboxes, landmarks
def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5):
"""
Get bbox coordinates and keypoints of faces on original image.
Params:
imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference)
conf_thres: confidence threshold for each prediction
iou_thres: threshold for NMS (filter of intersecting bboxes)
Returns:
bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
"""
# Pass input images through face detector
images = imgs if isinstance(imgs, list) else [imgs]
images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
origimgs = copy.deepcopy(images)
images = self._preprocess(images)
if IS_HIGH_VERSION:
with torch.inference_mode(): # for pytorch>=1.9
pred = self.detector(images)[0]
else:
with torch.no_grad(): # for pytorch<1.9
pred = self.detector(images)[0]
bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres)
# return bboxes, points
if not isListempty(points):
bboxes = np.array(bboxes).reshape(-1,4)
points = np.array(points).reshape(-1,10)
padding = bboxes[:,0].reshape(-1,1)
return np.concatenate((bboxes, padding, points), axis=1)
else:
return None
def __call__(self, *args):
return self.predict(*args)
================================================
FILE: facelib/detection/yolov5face/models/__init__.py
================================================
================================================
FILE: facelib/detection/yolov5face/models/common.py
================================================
# This file contains modules common to various models
import math
import numpy as np
import torch
from torch import nn
from facelib.detection.yolov5face.utils.datasets import letterbox
from facelib.detection.yolov5face.utils.general import (
make_divisible,
non_max_suppression,
scale_coords,
xyxy2xywh,
)
def autopad(k, p=None): # kernel, padding
# Pad to 'same'
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p
def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.data.size()
channels_per_group = torch.div(num_channels, groups, rounding_mode="trunc")
# reshape
x = x.view(batchsize, groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
return x.view(batchsize, -1, height, width)
def DWConv(c1, c2, k=1, s=1, act=True):
# Depthwise convolution
return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
class Conv(nn.Module):
# Standard convolution
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
def forward(self, x):
return self.act(self.bn(self.conv(x)))
def fuseforward(self, x):
return self.act(self.conv(x))
class StemBlock(nn.Module):
def __init__(self, c1, c2, k=3, s=2, p=None, g=1, act=True):
super().__init__()
self.stem_1 = Conv(c1, c2, k, s, p, g, act)
self.stem_2a = Conv(c2, c2 // 2, 1, 1, 0)
self.stem_2b = Conv(c2 // 2, c2, 3, 2, 1)
self.stem_2p = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
self.stem_3 = Conv(c2 * 2, c2, 1, 1, 0)
def forward(self, x):
stem_1_out = self.stem_1(x)
stem_2a_out = self.stem_2a(stem_1_out)
stem_2b_out = self.stem_2b(stem_2a_out)
stem_2p_out = self.stem_2p(stem_1_out)
return self.stem_3(torch.cat((stem_2b_out, stem_2p_out), 1))
class Bottleneck(nn.Module):
# Standard bottleneck
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_, c2, 3, 1, g=g)
self.add = shortcut and c1 == c2
def forward(self, x):
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
class BottleneckCSP(nn.Module):
# CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
self.cv4 = Conv(2 * c_, c2, 1, 1)
self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
self.act = nn.LeakyReLU(0.1, inplace=True)
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
def forward(self, x):
y1 = self.cv3(self.m(self.cv1(x)))
y2 = self.cv2(x)
return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
class C3(nn.Module):
# CSP Bottleneck with 3 convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
def forward(self, x):
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
class ShuffleV2Block(nn.Module):
def __init__(self, inp, oup, stride):
super().__init__()
if not 1 <= stride <= 3:
raise ValueError("illegal stride value")
self.stride = stride
branch_features = oup // 2
if self.stride > 1:
self.branch1 = nn.Sequential(
self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2d(inp),
nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.SiLU(),
)
else:
self.branch1 = nn.Sequential()
self.branch2 = nn.Sequential(
nn.Conv2d(
inp if (self.stride > 1) else branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(branch_features),
nn.SiLU(),
self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2d(branch_features),
nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.SiLU(),
)
@staticmethod
def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
def forward(self, x):
if self.stride == 1:
x1, x2 = x.chunk(2, dim=1)
out = torch.cat((x1, self.branch2(x2)), dim=1)
else:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
out = channel_shuffle(out, 2)
return out
class SPP(nn.Module):
# Spatial pyramid pooling layer used in YOLOv3-SPP
def __init__(self, c1, c2, k=(5, 9, 13)):
super().__init__()
c_ = c1 // 2 # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
def forward(self, x):
x = self.cv1(x)
return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
class Focus(nn.Module):
# Focus wh information into c-space
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
super().__init__()
self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
class Concat(nn.Module):
# Concatenate a list of tensors along dimension
def __init__(self, dimension=1):
super().__init__()
self.d = dimension
def forward(self, x):
return torch.cat(x, self.d)
class NMS(nn.Module):
# Non-Maximum Suppression (NMS) module
conf = 0.25 # confidence threshold
iou = 0.45 # IoU threshold
classes = None # (optional list) filter by class
def forward(self, x):
return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
class AutoShape(nn.Module):
# input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
img_size = 640 # inference size (pixels)
conf = 0.25 # NMS confidence threshold
iou = 0.45 # NMS IoU threshold
classes = None # (optional list) filter by class
def __init__(self, model):
super().__init__()
self.model = model.eval()
def autoshape(self):
print("autoShape already enabled, skipping... ") # model already converted to model.autoshape()
return self
def forward(self, imgs, size=640, augment=False, profile=False):
# Inference from various sources. For height=720, width=1280, RGB images example inputs are:
# OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
# PIL: = Image.open('image.jpg') # HWC x(720,1280,3)
# numpy: = np.zeros((720,1280,3)) # HWC
# torch: = torch.zeros(16,3,720,1280) # BCHW
# multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
p = next(self.model.parameters()) # for device and type
if isinstance(imgs, torch.Tensor): # torch
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
# Pre-process
n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
shape0, shape1 = [], [] # image and inference shapes
for i, im in enumerate(imgs):
im = np.array(im) # to numpy
if im.shape[0] < 5: # image in CHW
im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input
s = im.shape[:2] # HWC
shape0.append(s) # image shape
g = size / max(s) # gain
shape1.append([y * g for y in s])
imgs[i] = im # update
shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
x = np.stack(x, 0) if n > 1 else x[0][None] # stack
x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
x = torch.from_numpy(x).to(p.device).type_as(p) / 255.0 # uint8 to fp16/32
# Inference
with torch.no_grad():
y = self.model(x, augment, profile)[0] # forward
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
# Post-process
for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i])
return Detections(imgs, y, self.names)
class Detections:
# detections class for YOLOv5 inference results
def __init__(self, imgs, pred, names=None):
super().__init__()
d = pred[0].device # device
gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1.0, 1.0], device=d) for im in imgs] # normalizations
self.imgs = imgs # list of images as numpy arrays
self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
self.names = names # class names
self.xyxy = pred # xyxy pixels
self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
self.n = len(self.pred)
def __len__(self):
return self.n
def tolist(self):
# return a list of Detections objects, i.e. 'for result in results.tolist():'
x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)]
for d in x:
for k in ["imgs", "pred", "xyxy", "xyxyn", "xywh", "xywhn"]:
setattr(d, k, getattr(d, k)[0]) # pop out of list
return x
================================================
FILE: facelib/detection/yolov5face/models/experimental.py
================================================
# # This file contains experimental modules
import numpy as np
import torch
from torch import nn
from facelib.detection.yolov5face.models.common import Conv
class CrossConv(nn.Module):
# Cross Convolution Downsample
def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
# ch_in, ch_out, kernel, stride, groups, expansion, shortcut
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, (1, k), (1, s))
self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
self.add = shortcut and c1 == c2
def forward(self, x):
return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
class MixConv2d(nn.Module):
# Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
super().__init__()
groups = len(k)
if equal_ch: # equal c_ per group
i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices
c_ = [(i == g).sum() for g in range(groups)] # intermediate channels
else: # equal weight.numel() per group
b = [c2] + [0] * groups
a = np.eye(groups + 1, groups, k=-1)
a -= np.roll(a, 1, axis=1)
a *= np.array(k) ** 2
a[0] = 1
c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
self.bn = nn.BatchNorm2d(c2)
self.act = nn.LeakyReLU(0.1, inplace=True)
def forward(self, x):
return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
================================================
FILE: facelib/detection/yolov5face/models/yolo.py
================================================
import math
from copy import deepcopy
from pathlib import Path
import torch
import yaml # for torch hub
from torch import nn
from facelib.detection.yolov5face.models.common import (
C3,
NMS,
SPP,
AutoShape,
Bottleneck,
BottleneckCSP,
Concat,
Conv,
DWConv,
Focus,
ShuffleV2Block,
StemBlock,
)
from facelib.detection.yolov5face.models.experimental import CrossConv, MixConv2d
from facelib.detection.yolov5face.utils.autoanchor import check_anchor_order
from facelib.detection.yolov5face.utils.general import make_divisible
from facelib.detection.yolov5face.utils.torch_utils import copy_attr, fuse_conv_and_bn
class Detect(nn.Module):
stride = None # strides computed during build
export = False # onnx export
def __init__(self, nc=80, anchors=(), ch=()): # detection layer
super().__init__()
self.nc = nc # number of classes
self.no = nc + 5 + 10 # number of outputs per anchor
self.nl = len(anchors) # number of detection layers
self.na = len(anchors[0]) // 2 # number of anchors
self.grid = [torch.zeros(1)] * self.nl # init grid
a = torch.tensor(anchors).float().view(self.nl, -1, 2)
self.register_buffer("anchors", a) # shape(nl,na,2)
self.register_buffer("anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
def forward(self, x):
z = [] # inference output
if self.export:
for i in range(self.nl):
x[i] = self.m[i](x[i])
return x
for i in range(self.nl):
x[i] = self.m[i](x[i]) # conv
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
if not self.training: # inference
if self.grid[i].shape[2:4] != x[i].shape[2:4]:
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
y = torch.full_like(x[i], 0)
y[..., [0, 1, 2, 3, 4, 15]] = x[i][..., [0, 1, 2, 3, 4, 15]].sigmoid()
y[..., 5:15] = x[i][..., 5:15]
y[..., 0:2] = (y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
y[..., 5:7] = (
y[..., 5:7] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
) # landmark x1 y1
y[..., 7:9] = (
y[..., 7:9] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
) # landmark x2 y2
y[..., 9:11] = (
y[..., 9:11] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
) # landmark x3 y3
y[..., 11:13] = (
y[..., 11:13] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
) # landmark x4 y4
y[..., 13:15] = (
y[..., 13:15] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
) # landmark x5 y5
z.append(y.view(bs, -1, self.no))
return x if self.training else (torch.cat(z, 1), x)
@staticmethod
def _make_grid(nx=20, ny=20):
# yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij") # for pytorch>=1.10
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
class Model(nn.Module):
def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None): # model, input channels, number of classes
super().__init__()
self.yaml_file = Path(cfg).name
with Path(cfg).open(encoding="utf8") as f:
self.yaml = yaml.safe_load(f) # model dict
# Define model
ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
if nc and nc != self.yaml["nc"]:
self.yaml["nc"] = nc # override yaml value
self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
self.names = [str(i) for i in range(self.yaml["nc"])] # default names
# Build strides, anchors
m = self.model[-1] # Detect()
if isinstance(m, Detect):
s = 128 # 2x min stride
m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
m.anchors /= m.stride.view(-1, 1, 1)
check_anchor_order(m)
self.stride = m.stride
self._initialize_biases() # only run once
def forward(self, x):
return self.forward_once(x) # single-scale inference, train
def forward_once(self, x):
y = [] # outputs
for m in self.model:
if m.f != -1: # if not from previous layer
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
x = m(x) # run
y.append(x if m.i in self.save else None) # save output
return x
def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
# https://arxiv.org/abs/1708.02002 section 3.3
m = self.model[-1] # Detect() module
for mi, s in zip(m.m, m.stride): # from
b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
def _print_biases(self):
m = self.model[-1] # Detect() module
for mi in m.m: # from
b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
print(("%6g Conv2d.bias:" + "%10.3g" * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
print("Fusing layers... ")
for m in self.model.modules():
if isinstance(m, Conv) and hasattr(m, "bn"):
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
delattr(m, "bn") # remove batchnorm
m.forward = m.fuseforward # update forward
elif type(m) is nn.Upsample:
m.recompute_scale_factor = None # torch 1.11.0 compatibility
return self
def nms(self, mode=True): # add or remove NMS module
present = isinstance(self.model[-1], NMS) # last layer is NMS
if mode and not present:
print("Adding NMS... ")
m = NMS() # module
m.f = -1 # from
m.i = self.model[-1].i + 1 # index
self.model.add_module(name=str(m.i), module=m) # add
self.eval()
elif not mode and present:
print("Removing NMS... ")
self.model = self.model[:-1] # remove
return self
def autoshape(self): # add autoShape module
print("Adding autoShape... ")
m = AutoShape(self) # wrap model
copy_attr(m, self, include=("yaml", "nc", "hyp", "names", "stride"), exclude=()) # copy attributes
return m
def parse_model(d, ch): # model_dict, input_channels(3)
anchors, nc, gd, gw = d["anchors"], d["nc"], d["depth_multiple"], d["width_multiple"]
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
m = eval(m) if isinstance(m, str) else m # eval strings
for j, a in enumerate(args):
try:
args[j] = eval(a) if isinstance(a, str) else a # eval strings
except:
pass
n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [
Conv,
Bottleneck,
SPP,
DWConv,
MixConv2d,
Focus,
CrossConv,
BottleneckCSP,
C3,
ShuffleV2Block,
StemBlock,
]:
c1, c2 = ch[f], args[0]
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
args = [c1, c2, *args[1:]]
if m in [BottleneckCSP, C3]:
args.insert(2, n)
n = 1
elif m is nn.BatchNorm2d:
args = [ch[f]]
elif m is Concat:
c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
elif m is Detect:
args.append([ch[x + 1] for x in f])
if isinstance(args[1], int): # number of anchors
args[1] = [list(range(args[1] * 2))] * len(f)
else:
c2 = ch[f]
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
t = str(m)[8:-2].replace("__main__.", "") # module type
np = sum(x.numel() for x in m_.parameters()) # number params
m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
layers.append(m_)
ch.append(c2)
return nn.Sequential(*layers), sorted(save)
================================================
FILE: facelib/detection/yolov5face/models/yolov5l.yaml
================================================
# parameters
nc: 1 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
# anchors
anchors:
- [4,5, 8,10, 13,16] # P3/8
- [23,29, 43,55, 73,105] # P4/16
- [146,217, 231,300, 335,433] # P5/32
# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 2-P3/8
[-1, 9, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 4-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 6-P5/32
[-1, 1, SPP, [1024, [3,5,7]]],
[-1, 3, C3, [1024, False]], # 8
]
# YOLOv5 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 5], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 12
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 3], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 16 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 13], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 19 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 9], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 22 (P5/32-large)
[[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
================================================
FILE: facelib/detection/yolov5face/models/yolov5n.yaml
================================================
# parameters
nc: 1 # number of classes
depth_multiple: 1.0 # model depth multiple
width_multiple: 1.0 # layer channel multiple
# anchors
anchors:
- [4,5, 8,10, 13,16] # P3/8
- [23,29, 43,55, 73,105] # P4/16
- [146,217, 231,300, 335,433] # P5/32
# YOLOv5 backbone
backbone:
# [from, number, module, args]
[[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4
[-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8
[-1, 3, ShuffleV2Block, [128, 1]], # 2
[-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16
[-1, 7, ShuffleV2Block, [256, 1]], # 4
[-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32
[-1, 3, ShuffleV2Block, [512, 1]], # 6
]
# YOLOv5 head
head:
[[-1, 1, Conv, [128, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P4
[-1, 1, C3, [128, False]], # 10
[-1, 1, Conv, [128, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 2], 1, Concat, [1]], # cat backbone P3
[-1, 1, C3, [128, False]], # 14 (P3/8-small)
[-1, 1, Conv, [128, 3, 2]],
[[-1, 11], 1, Concat, [1]], # cat head P4
[-1, 1, C3, [128, False]], # 17 (P4/16-medium)
[-1, 1, Conv, [128, 3, 2]],
[[-1, 7], 1, Concat, [1]], # cat head P5
[-1, 1, C3, [128, False]], # 20 (P5/32-large)
[[14, 17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
================================================
FILE: facelib/detection/yolov5face/utils/__init__.py
================================================
================================================
FILE: facelib/detection/yolov5face/utils/autoanchor.py
================================================
# Auto-anchor utils
def check_anchor_order(m):
# Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
a = m.anchor_grid.prod(-1).view(-1) # anchor area
da = a[-1] - a[0] # delta a
ds = m.stride[-1] - m.stride[0] # delta s
if da.sign() != ds.sign(): # same order
print("Reversing anchor order")
m.anchors[:] = m.anchors.flip(0)
m.anchor_grid[:] = m.anchor_grid.flip(0)
================================================
FILE: facelib/detection/yolov5face/utils/datasets.py
================================================
import cv2
import numpy as np
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True):
# Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
shape = img.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
if not scaleup: # only scale down, do not scale up (for better test mAP)
r = min(r, 1.0)
# Compute padding
ratio = r, r # width, height ratios
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
if auto: # minimum rectangle
dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding
elif scale_fill: # stretch
dw, dh = 0.0, 0.0
new_unpad = (new_shape[1], new_shape[0])
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
dw /= 2 # divide padding into 2 sides
dh /= 2
if shape[::-1] != new_unpad: # resize
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
return img, ratio, (dw, dh)
================================================
FILE: facelib/detection/yolov5face/utils/extract_ckpt.py
================================================
import torch
import sys
sys.path.insert(0,'./facelib/detection/yolov5face')
model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model']
torch.save(model.state_dict(),'ckpts/facelib/yolov5n-face.pth')
================================================
FILE: facelib/detection/yolov5face/utils/general.py
================================================
import math
import time
import numpy as np
import torch
import torchvision
def check_img_size(img_size, s=32):
# Verify img_size is a multiple of stride s
new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
# if new_size != img_size:
# print(f"WARNING: --img-size {img_size:g} must be multiple of max stride {s:g}, updating to {new_size:g}")
return new_size
def make_divisible(x, divisor):
# Returns x evenly divisible by divisor
return math.ceil(x / divisor) * divisor
def xyxy2xywh(x):
# Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
y[:, 2] = x[:, 2] - x[:, 0] # width
y[:, 3] = x[:, 3] - x[:, 1] # height
return y
def xywh2xyxy(x):
# Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
return y
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
# Rescale coords (xyxy) from img1_shape to img0_shape
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
else:
gain = ratio_pad[0][0]
pad = ratio_pad[1]
coords[:, [0, 2]] -= pad[0] # x padding
coords[:, [1, 3]] -= pad[1] # y padding
coords[:, :4] /= gain
clip_coords(coords, img0_shape)
return coords
def clip_coords(boxes, img_shape):
# Clip bounding xyxy bounding boxes to image shape (height, width)
boxes[:, 0].clamp_(0, img_shape[1]) # x1
boxes[:, 1].clamp_(0, img_shape[0]) # y1
boxes[:, 2].clamp_(0, img_shape[1]) # x2
boxes[:, 3].clamp_(0, img_shape[0]) # y2
def box_iou(box1, box2):
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments:
box1 (Tensor[N, 4])
box2 (Tensor[M, 4])
Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2
"""
def box_area(box):
return (box[2] - box[0]) * (box[3] - box[1])
area1 = box_area(box1.T)
area2 = box_area(box2.T)
inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
return inter / (area1[:, None] + area2 - inter)
def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
"""Performs Non-Maximum Suppression (NMS) on inference results
Returns:
detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
"""
nc = prediction.shape[2] - 15 # number of classes
xc = prediction[..., 4] > conf_thres # candidates
# Settings
# (pixels) maximum box width and height
max_wh = 4096
time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS
t = time.time()
output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0]
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
x = x[xc[xi]] # confidence
# Cat apriori labels if autolabelling
if labels and len(labels[xi]):
label = labels[xi]
v = torch.zeros((len(label), nc + 15), device=x.device)
v[:, :4] = label[:, 1:5] # box
v[:, 4] = 1.0 # conf
v[range(len(label)), label[:, 0].long() + 15] = 1.0 # cls
x = torch.cat((x, v), 0)
# If none remain process next image
if not x.shape[0]:
continue
# Compute conf
x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
box = xywh2xyxy(x[:, :4])
# Detections matrix nx6 (xyxy, conf, landmarks, cls)
if multi_label:
i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T
x = torch.cat((box[i], x[i, j + 15, None], x[:, 5:15], j[:, None].float()), 1)
else: # best class only
conf, j = x[:, 15:].max(1, keepdim=True)
x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres]
# Filter by class
if classes is not None:
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
# If none remain process next image
n = x.shape[0] # number of boxes
if not n:
continue
# Batched NMS
c = x[:, 15:16] * (0 if agnostic else max_wh) # classes
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
weights = iou * scores[None] # box weights
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
if redundant:
i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
if (time.time() - t) > time_limit:
break # time limit exceeded
return output
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
"""Performs Non-Maximum Suppression (NMS) on inference results
Returns:
detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
"""
nc = prediction.shape[2] - 5 # number of classes
xc = prediction[..., 4] > conf_thres # candidates
# Settings
# (pixels) maximum box width and height
max_wh = 4096
time_limit = 10.0 # seconds to quit after
redundant = True # require redundant detections
multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS
t = time.time()
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
for xi, x in enumerate(prediction): # image index, image inference
x = x[xc[xi]] # confidence
# Cat apriori labels if autolabelling
if labels and len(labels[xi]):
label_id = labels[xi]
v = torch.zeros((len(label_id), nc + 5), device=x.device)
v[:, :4] = label_id[:, 1:5] # box
v[:, 4] = 1.0 # conf
v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0 # cls
x = torch.cat((x, v), 0)
# If none remain process next image
if not x.shape[0]:
continue
# Compute conf
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
box = xywh2xyxy(x[:, :4])
# Detections matrix nx6 (xyxy, conf, cls)
if multi_label:
i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
else: # best class only
conf, j = x[:, 5:].max(1, keepdim=True)
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
# Filter by class
if classes is not None:
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
x = x[x[:, 4].argsort(descending=True)] # sort by confidence
# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
weights = iou * scores[None] # box weights
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
if redundant:
i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
if (time.time() - t) > time_limit:
print(f"WARNING: NMS time limit {time_limit}s exceeded")
break # time limit exceeded
return output
def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
# Rescale coords (xyxy) from img1_shape to img0_shape
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
else:
gain = ratio_pad[0][0]
pad = ratio_pad[1]
coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding
coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding
coords[:, :10] /= gain
coords[:, 0].clamp_(0, img0_shape[1]) # x1
coords[:, 1].clamp_(0, img0_shape[0]) # y1
coords[:, 2].clamp_(0, img0_shape[1]) # x2
coords[:, 3].clamp_(0, img0_shape[0]) # y2
coords[:, 4].clamp_(0, img0_shape[1]) # x3
coords[:, 5].clamp_(0, img0_shape[0]) # y3
coords[:, 6].clamp_(0, img0_shape[1]) # x4
coords[:, 7].clamp_(0, img0_shape[0]) # y4
coords[:, 8].clamp_(0, img0_shape[1]) # x5
coords[:, 9].clamp_(0, img0_shape[0]) # y5
return coords
================================================
FILE: facelib/detection/yolov5face/utils/torch_utils.py
================================================
import torch
from torch import nn
def fuse_conv_and_bn(conv, bn):
# Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
fusedconv = (
nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
groups=conv.groups,
bias=True,
)
.requires_grad_(False)
.to(conv.weight.device)
)
# prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
# prepare spatial bias
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
return fusedconv
def copy_attr(a, b, include=(), exclude=()):
# Copy attributes from b to a, options to only include [...] and to exclude [...]
for k, v in b.__dict__.items():
if (include and k not in include) or k.startswith("_") or k in exclude:
continue
setattr(a, k, v)
================================================
FILE: facelib/parsing/__init__.py
================================================
import torch
from facelib.utils import load_file_from_url
from .bisenet import BiSeNet
from .parsenet import ParseNet
def init_parsing_model(model_name='bisenet', half=False, device='cuda'):
if model_name == 'bisenet':
model = BiSeNet(num_class=19)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth'
elif model_name == 'parsenet':
model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
else:
raise NotImplementedError(f'{model_name} is not implemented.')
model_path = load_file_from_url(url=model_url, model_dir='ckpts/facelib', progress=True, file_name=None)
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
model.load_state_dict(load_net, strict=True)
model.eval()
model = model.to(device)
return model
================================================
FILE: facelib/parsing/bisenet.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from .resnet import ResNet18
class ConvBNReLU(nn.Module):
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
self.bn = nn.BatchNorm2d(out_chan)
def forward(self, x):
x = self.conv(x)
x = F.relu(self.bn(x))
return x
class BiSeNetOutput(nn.Module):
def __init__(self, in_chan, mid_chan, num_class):
super(BiSeNetOutput, self).__init__()
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False)
def forward(self, x):
feat = self.conv(x)
out = self.conv_out(feat)
return out, feat
class AttentionRefinementModule(nn.Module):
def __init__(self, in_chan, out_chan):
super(AttentionRefinementModule, self).__init__()
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
self.bn_atten = nn.BatchNorm2d(out_chan)
self.sigmoid_atten = nn.Sigmoid()
def forward(self, x):
feat = self.conv(x)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv_atten(atten)
atten = self.bn_atten(atten)
atten = self.sigmoid_atten(atten)
out = torch.mul(feat, atten)
return out
class ContextPath(nn.Module):
def __init__(self):
super(ContextPath, self).__init__()
self.resnet = ResNet18()
self.arm16 = AttentionRefinementModule(256, 128)
self.arm32 = AttentionRefinementModule(512, 128)
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
def forward(self, x):
feat8, feat16, feat32 = self.resnet(x)
h8, w8 = feat8.size()[2:]
h16, w16 = feat16.size()[2:]
h32, w32 = feat32.size()[2:]
avg = F.avg_pool2d(feat32, feat32.size()[2:])
avg = self.conv_avg(avg)
avg_up = F.interpolate(avg, (h32, w32), mode='nearest')
feat32_arm = self.arm32(feat32)
feat32_sum = feat32_arm + avg_up
feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest')
feat32_up = self.conv_head32(feat32_up)
feat16_arm = self.arm16(feat16)
feat16_sum = feat16_arm + feat32_up
feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest')
feat16_up = self.conv_head16(feat16_up)
return feat8, feat16_up, feat32_up # x8, x8, x16
class FeatureFusionModule(nn.Module):
def __init__(self, in_chan, out_chan):
super(FeatureFusionModule, self).__init__()
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
def forward(self, fsp, fcp):
fcat = torch.cat([fsp, fcp], dim=1)
feat = self.convblk(fcat)
atten = F.avg_pool2d(feat, feat.size()[2:])
atten = self.conv1(atten)
atten = self.relu(atten)
atten = self.conv2(atten)
atten = self.sigmoid(atten)
feat_atten = torch.mul(feat, atten)
feat_out = feat_atten + feat
return feat_out
class BiSeNet(nn.Module):
def __init__(self, num_class):
super(BiSeNet, self).__init__()
self.cp = ContextPath()
self.ffm = FeatureFusionModule(256, 256)
self.conv_out = BiSeNetOutput(256, 256, num_class)
self.conv_out16 = BiSeNetOutput(128, 64, num_class)
self.conv_out32 = BiSeNetOutput(128, 64, num_class)
def forward(self, x, return_feat=False):
h, w = x.size()[2:]
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature
feat_sp = feat_res8 # replace spatial path feature with res3b1 feature
feat_fuse = self.ffm(feat_sp, feat_cp8)
out, feat = self.conv_out(feat_fuse)
out16, feat16 = self.conv_out16(feat_cp8)
out32, feat32 = self.conv_out32(feat_cp16)
out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True)
out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True)
if return_feat:
feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True)
feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True)
feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True)
return out, out16, out32, feat, feat16, feat32
else:
return out, out16, out32
================================================
FILE: facelib/parsing/parsenet.py
================================================
"""Modified from https://github.com/chaofengc/PSFRGAN
"""
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
class NormLayer(nn.Module):
"""Normalization Layers.
Args:
channels: input channels, for batch norm and instance norm.
input_size: input shape without batch size, for layer norm.
"""
def __init__(self, channels, normalize_shape=None, norm_type='bn'):
super(NormLayer, self).__init__()
norm_type = norm_type.lower()
self.norm_type = norm_type
if norm_type == 'bn':
self.norm = nn.BatchNorm2d(channels, affine=True)
elif norm_type == 'in':
self.norm = nn.InstanceNorm2d(channels, affine=False)
elif norm_type == 'gn':
self.norm = nn.GroupNorm(32, channels, affine=True)
elif norm_type == 'pixel':
self.norm = lambda x: F.normalize(x, p=2, dim=1)
elif norm_type == 'layer':
self.norm = nn.LayerNorm(normalize_shape)
elif norm_type == 'none':
self.norm = lambda x: x * 1.0
else:
assert 1 == 0, f'Norm type {norm_type} not support.'
def forward(self, x, ref=None):
if self.norm_type == 'spade':
return self.norm(x, ref)
else:
return self.norm(x)
class ReluLayer(nn.Module):
"""Relu Layer.
Args:
relu type: type of relu layer, candidates are
- ReLU
- LeakyReLU: default relu slope 0.2
- PRelu
- SELU
- none: direct pass
"""
def __init__(self, channels, relu_type='relu'):
super(ReluLayer, self).__init__()
relu_type = relu_type.lower()
if relu_type == 'relu':
self.func = nn.ReLU(True)
elif relu_type == 'leakyrelu':
self.func = nn.LeakyReLU(0.2, inplace=True)
elif relu_type == 'prelu':
self.func = nn.PReLU(channels)
elif relu_type == 'selu':
self.func = nn.SELU(True)
elif relu_type == 'none':
self.func = lambda x: x * 1.0
else:
assert 1 == 0, f'Relu type {relu_type} not support.'
def forward(self, x):
return self.func(x)
class ConvLayer(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
scale='none',
norm_type='none',
relu_type='none',
use_pad=True,
bias=True):
super(ConvLayer, self).__init__()
self.use_pad = use_pad
self.norm_type = norm_type
if norm_type in ['bn']:
bias = False
stride = 2 if scale == 'down' else 1
self.scale_func = lambda x: x
if scale == 'up':
self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2)))
self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
self.relu = ReluLayer(out_channels, relu_type)
self.norm = NormLayer(out_channels, norm_type=norm_type)
def forward(self, x):
out = self.scale_func(x)
if self.use_pad:
out = self.reflection_pad(out)
out = self.conv2d(out)
out = self.norm(out)
out = self.relu(out)
return out
class ResidualBlock(nn.Module):
"""
Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
"""
def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
super(ResidualBlock, self).__init__()
if scale == 'none' and c_in == c_out:
self.shortcut_func = lambda x: x
else:
self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
scale_conf = scale_config_dict[scale]
self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none')
def forward(self, x):
identity = self.shortcut_func(x)
res = self.conv1(x)
res = self.conv2(res)
return identity + res
class ParseNet(nn.Module):
def __init__(self,
in_size=128,
out_size=128,
min_feat_size=32,
base_ch=64,
parsing_ch=19,
res_depth=10,
relu_type='LeakyReLU',
norm_type='bn',
ch_range=[32, 256]):
super().__init__()
self.res_depth = res_depth
act_args = {'norm_type': norm_type, 'relu_type': relu_type}
min_ch, max_ch = ch_range
ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731
min_feat_size = min(in_size, min_feat_size)
down_steps = int(np.log2(in_size // min_feat_size))
up_steps = int(np.log2(out_size // min_feat_size))
# =============== define encoder-body-decoder ====================
self.encoder = []
self.encoder.append(ConvLayer(3, base_ch, 3, 1))
head_ch = base_ch
for i in range(down_steps):
cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
head_ch = head_ch * 2
self.body = []
for i in range(res_depth):
self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))
self.decoder = []
for i in range(up_steps):
cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
head_ch = head_ch // 2
self.encoder = nn.Sequential(*self.encoder)
self.body = nn.Sequential(*self.body)
self.decoder = nn.Sequential(*self.decoder)
self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)
def forward(self, x):
feat = self.encoder(x)
x = feat + self.body(feat)
x = self.decoder(x)
out_img = self.out_img_conv(x)
out_mask = self.out_mask_conv(x)
return out_mask, out_img
================================================
FILE: facelib/parsing/resnet.py
================================================
import torch.nn as nn
import torch.nn.functional as F
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
class BasicBlock(nn.Module):
def __init__(self, in_chan, out_chan, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_chan, out_chan, stride)
self.bn1 = nn.BatchNorm2d(out_chan)
self.conv2 = conv3x3(out_chan, out_chan)
self.bn2 = nn.BatchNorm2d(out_chan)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
if in_chan != out_chan or stride != 1:
self.downsample = nn.Sequential(
nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_chan),
)
def forward(self, x):
residual = self.conv1(x)
residual = F.relu(self.bn1(residual))
residual = self.conv2(residual)
residual = self.bn2(residual)
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(x)
out = shortcut + residual
out = self.relu(out)
return out
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
for i in range(bnum - 1):
layers.append(BasicBlock(out_chan, out_chan, stride=1))
return nn.Sequential(*layers)
class ResNet18(nn.Module):
def __init__(self):
super(ResNet18, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
def forward(self, x):
x = self.conv1(x)
x = F.relu(self.bn1(x))
x = self.maxpool(x)
x = self.layer1(x)
feat8 = self.layer2(x) # 1/8
feat16 = self.layer3(feat8) # 1/16
feat32 = self.layer4(feat16) # 1/32
return feat8, feat16, feat32
================================================
FILE: facelib/utils/__init__.py
================================================
from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back
from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir
__all__ = [
'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url',
'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir'
]
================================================
FILE: facelib/utils/face_restoration_helper.py
================================================
import cv2
import numpy as np
import os
import torch
import pdb
import dlib
from torchvision.transforms.functional import normalize
from facelib.detection import init_detection_model
from facelib.parsing import init_parsing_model
from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray, adain_npy
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils.misc import get_device
dlib_model_url = {
'face_detector': 'https://github.com/jnjaby/KEEP/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat',
'shape_predictor_5': 'https://github.com/jnjaby/KEEP/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat'
}
# is the test part
dlib_model_path = {
'face_detector': "./ckpts/dlib/mmod_human_face_detector.dat",
'shape_predictor_5' : "./ckpts/dlib/shape_predictor_5_face_landmarks.dat"
}
def get_largest_face(det_faces, h, w):
def get_location(val, length):
if val < 0:
return 0
elif val > length:
return length
else:
return val
face_areas = []
for det_face in det_faces:
left = get_location(det_face[0], w)
right = get_location(det_face[2], w)
top = get_location(det_face[1], h)
bottom = get_location(det_face[3], h)
face_area = (right - left) * (bottom - top)
face_areas.append(face_area)
largest_idx = face_areas.index(max(face_areas))
return det_faces[largest_idx], largest_idx
def get_center_face(det_faces, h=0, w=0, center=None):
if center is not None:
center = np.array(center)
else:
center = np.array([w / 2, h / 2])
center_dist = []
for det_face in det_faces:
face_center = np.array(
[(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2]
)
dist = np.linalg.norm(face_center - center)
center_dist.append(dist)
center_idx = center_dist.index(min(center_dist))
return det_faces[center_idx], center_idx
class FaceRestoreHelper(object):
"""Helper for the face restoration pipeline (base class)."""
def __init__(
self,
upscale_factor,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
template_3points=False,
pad_blur=False,
use_parse=False,
device=None,
):
self.template_3points = template_3points # improve robustness
self.upscale_factor = int(upscale_factor)
# the cropped face ratio based on the square face
self.crop_ratio = crop_ratio # (h, w)
assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
self.face_size = (
int(face_size * self.crop_ratio[1]),
int(face_size * self.crop_ratio[0]),
)
self.det_model = det_model
if self.det_model == 'dlib':
# standard 5 landmarks for FFHQ faces with 1024 x 1024
self.face_template = np.array(
[
[686.77227723, 488.62376238],
[586.77227723, 493.59405941],
[337.91089109, 488.38613861],
[437.95049505, 493.51485149],
[513.58415842, 678.5049505],
]
)
self.face_template = self.face_template / (1024 // face_size)
elif self.template_3points:
self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
else:
# standard 5 landmarks for FFHQ faces with 512 x 512
# facexlib
self.face_template = np.array(
[
[192.98138, 239.94708],
[318.90277, 240.1936],
[256.63416, 314.01935],
[201.26117, 371.41043],
[313.08905, 371.15118],
]
)
# dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
# self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
# [198.22603, 372.82502], [313.91018, 372.75659]])
self.face_template = self.face_template * (face_size / 512.0)
if self.crop_ratio[0] > 1:
self.face_template[:, 1] += face_size * \
(self.crop_ratio[0] - 1) / 2
if self.crop_ratio[1] > 1:
self.face_template[:, 0] += face_size * \
(self.crop_ratio[1] - 1) / 2
self.save_ext = save_ext
self.pad_blur = pad_blur
if self.pad_blur is True:
self.template_3points = False
self.all_landmarks_5 = []
self.det_faces = []
self.affine_matrices = []
self.inverse_affine_matrices = []
self.cropped_faces = []
self.restored_faces = []
self.pad_input_imgs = []
if device is None:
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = get_device()
else:
self.device = device
# init face detection model
if self.det_model == 'dlib':
self.face_detector, self.shape_predictor_5 = self.init_dlib(
dlib_model_path['face_detector'], dlib_model_path['shape_predictor_5'])
else:
self.face_detector = init_detection_model(
det_model, half=False, device=self.device)
# init face parsing model
self.use_parse = use_parse
self.face_parse = init_parsing_model(
model_name='parsenet', device=self.device)
def set_upscale_factor(self, upscale_factor):
self.upscale_factor = upscale_factor
def read_image(self, img):
"""img can be image path or cv2 loaded image."""
# self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
if isinstance(img, str):
img = cv2.imread(img)
if np.max(img) > 256: # 16-bit image
img = img / 65535 * 255
if len(img.shape) == 2: # gray image
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif img.shape[2] == 4: # BGRA image with alpha channel
img = img[:, :, 0:3]
self.input_img = img
self.is_gray = is_gray(img, threshold=10)
if self.is_gray:
print('Grayscale input: True')
if min(self.input_img.shape[:2]) < 512:
f = 512.0/min(self.input_img.shape[:2])
self.input_img = cv2.resize(
self.input_img, (0, 0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
def init_dlib(self, detection_path, landmark5_path):
"""Initialize the dlib detectors and predictors."""
try:
import dlib
except ImportError:
print('Please install dlib by running:' 'conda install -c conda-forge dlib')
# detection_path = load_file_from_url(
# url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
# landmark5_path = load_file_from_url(
# url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
face_detector = dlib.cnn_face_detection_model_v1(detection_path)
shape_predictor_5 = dlib.shape_predictor(landmark5_path)
return face_detector, shape_predictor_5
def get_face_landmarks_5_dlib(self,
only_keep_largest=False,
scale=1):
det_faces = self.face_detector(self.input_img, scale)
if len(det_faces) == 0:
# print('No face detected. Try to increase upsample_num_times.')
return 0
else:
if only_keep_largest:
# print('Detect several faces and only keep the largest.')
face_areas = []
for i in range(len(det_faces)):
face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
det_faces[i].rect.bottom() - det_faces[i].rect.top())
face_areas.append(face_area)
largest_idx = face_areas.index(max(face_areas))
self.det_faces = [det_faces[largest_idx]]
else:
self.det_faces = det_faces
if len(self.det_faces) == 0:
return 0
for face in self.det_faces:
shape = self.shape_predictor_5(self.input_img, face.rect)
landmark = np.array([[part.x, part.y] for part in shape.parts()])
self.all_landmarks_5.append(landmark)
return len(self.all_landmarks_5)
def get_face_landmarks_5(self,
only_keep_largest=False,
only_center_face=False,
resize=None,
blur_ratio=0.01,
eye_dist_threshold=None):
if self.det_model == 'dlib':
return self.get_face_landmarks_5_dlib(only_keep_largest)
if resize is None:
scale = 1
input_img = self.input_img
else:
h, w = self.input_img.shape[0:2]
scale = resize / min(h, w)
scale = max(1, scale) # always scale up
h, w = int(h * scale), int(w * scale)
interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
input_img = cv2.resize(
self.input_img, (w, h), interpolation=interp)
with torch.no_grad():
bboxes = self.face_detector.detect_faces(input_img)
if bboxes is None or bboxes.shape[0] == 0:
return 0
else:
bboxes = bboxes / scale
for bbox in bboxes:
# remove faces with too small eye distance: side faces or too small faces
eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
continue
if self.template_3points:
landmark = np.array([[bbox[i], bbox[i + 1]]
for i in range(5, 11, 2)])
else:
landmark = np.array([[bbox[i], bbox[i + 1]]
for i in range(5, 15, 2)])
self.all_landmarks_5.append(landmark)
self.det_faces.append(bbox[0:5])
if len(self.det_faces) == 0:
return 0
if only_keep_largest:
h, w, _ = self.input_img.shape
self.det_faces, largest_idx = get_largest_face(
self.det_faces, h, w)
self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
elif only_center_face:
h, w, _ = self.input_img.shape
self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
# pad blurry images
if self.pad_blur:
self.pad_input_imgs = []
for landmarks in self.all_landmarks_5:
# get landmarks
eye_left = landmarks[0, :]
eye_right = landmarks[1, :]
eye_avg = (eye_left + eye_right) * 0.5
mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
eye_to_eye = eye_right - eye_left
eye_to_mouth = mouth_avg - eye_avg
# Get the oriented crop rectangle
# x: half width of the oriented crop rectangle
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
# - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
# norm with the hypotenuse: get the direction
x /= np.hypot(*x) # get the hypotenuse of a right triangle
rect_scale = 1.5
x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale,
np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
# y: half height of the oriented crop rectangle
y = np.flipud(x) * [-1, 1]
# c: center
c = eye_avg + eye_to_mouth * 0.1
# quad: (left_top, left_bottom, right_bottom, right_top)
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
# qsize: side length of the square
qsize = np.hypot(*x) * 2
border = max(int(np.rint(qsize * 0.1)), 3)
# get pad
# pad: (width_left, height_top, width_right, height_bottom)
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
int(np.ceil(max(quad[:, 1]))))
pad = [
max(-pad[0] + border, 1),
max(-pad[1] + border, 1),
max(pad[2] - self.input_img.shape[0] + border, 1),
max(pad[3] - self.input_img.shape[1] + border, 1)
]
if max(pad) > 1:
# pad image
pad_img = np.pad(
self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
# modify landmark coords
landmarks[:, 0] += pad[0]
landmarks[:, 1] += pad[1]
# blur pad images
h, w, _ = pad_img.shape
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
np.float32(w - 1 - x) / pad[2]),
1.0 - np.minimum(np.float32(y) / pad[1],
np.float32(h - 1 - y) / pad[3]))
blur = int(qsize * blur_ratio)
if blur % 2 == 0:
blur += 1
blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
# blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
pad_img = pad_img.astype('float32')
pad_img += (blur_img - pad_img) * \
np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
pad_img += (np.median(pad_img, axis=(0, 1)) -
pad_img) * np.clip(mask, 0.0, 1.0)
pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
self.pad_input_imgs.append(pad_img)
else:
self.pad_input_imgs.append(np.copy(self.input_img))
return len(self.all_landmarks_5)
def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
"""Align and warp faces with face template.
"""
if self.pad_blur:
assert len(self.pad_input_imgs) == len(
self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
for idx, landmark in enumerate(self.all_landmarks_5):
# use 5 landmarks to get affine matrix
# use cv2.LMEDS method for the equivalence to skimage transform
# ref: https://blog.csdn.net/yichxi/article/details/115827338
affine_matrix = cv2.estimateAffinePartial2D(
landmark, self.face_template, method=cv2.LMEDS
)[0]
self.affine_matrices.append(affine_matrix)
# warp and crop faces
if border_mode == 'constant':
border_mode = cv2.BORDER_CONSTANT
elif border_mode == 'reflect101':
border_mode = cv2.BORDER_REFLECT101
elif border_mode == 'reflect':
border_mode = cv2.BORDER_REFLECT
if self.pad_blur:
input_img = self.pad_input_imgs[idx]
else:
input_img = self.input_img
# pdb.set_trace()
cropped_face = cv2.warpAffine(
input_img,
affine_matrix,
self.face_size,
borderMode=border_mode,
borderValue=(135, 133, 132),
) # gray
self.cropped_faces.append(cropped_face)
# save the cropped face
if save_cropped_path is not None:
path = os.path.splitext(save_cropped_path)[0]
save_path = f'{path}_{idx:02d}.{self.save_ext}'
imwrite(cropped_face, save_path)
def get_inverse_affine(self, save_inverse_affine_path=None):
"""Get inverse affine matrix."""
for idx, affine_matrix in enumerate(self.affine_matrices):
inverse_affine = cv2.invertAffineTransform(affine_matrix)
inverse_affine *= self.upscale_factor
self.inverse_affine_matrices.append(inverse_affine)
# save inverse affine matrices
if save_inverse_affine_path is not None:
path, _ = os.path.splitext(save_inverse_affine_path)
save_path = f'{path}_{idx:02d}.pth'
torch.save(inverse_affine, save_path)
def add_restored_face(self, restored_face, input_face=None):
if self.is_gray:
# convert img into grayscale
restored_face = bgr2gray(restored_face)
if input_face is not None:
restored_face = adain_npy(
restored_face, input_face) # transfer the color
self.restored_faces.append(restored_face)
def paste_faces_to_input_image(
self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None
):
h, w, _ = self.input_img.shape
h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
if upsample_img is None:
# simply resize the background
# upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
upsample_img = cv2.resize(
self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR
)
else:
upsample_img = cv2.resize(
upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4
)
assert len(self.restored_faces) == len(
self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
inv_mask_borders = []
for restored_face, inverse_affine in zip(
self.restored_faces, self.inverse_affine_matrices
):
if face_upsampler is not None:
restored_face = face_upsampler.enhance(
restored_face, outscale=self.upscale_factor
)[0]
inverse_affine /= self.upscale_factor
inverse_affine[:, 2] *= self.upscale_factor
face_size = (
self.face_size[0] * self.upscale_factor,
self.face_size[1] * self.upscale_factor,
)
else:
# Add an offset to inverse affine matrix, for more precise back alignment
if self.upscale_factor > 1:
extra_offset = 0.5 * self.upscale_factor
else:
extra_offset = 0
inverse_affine[:, 2] += extra_offset
face_size = self.face_size
inv_restored = cv2.warpAffine(
restored_face, inverse_affine, (w_up, h_up))
# always use square mask
mask = np.ones(face_size, dtype=np.float32)
inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
# remove the black borders
inv_mask_erosion = cv2.erode(
inv_mask,
np.ones(
(int(2 * self.upscale_factor), int(2 * self.upscale_factor)),
np.uint8,
),
)
pasted_face = inv_mask_erosion[:, :, None] * inv_restored
total_face_area = np.sum(inv_mask_erosion) # // 3
# add border
if draw_box:
h, w = face_size
mask_border = np.ones((h, w, 3), dtype=np.float32)
border = int(1400 / np.sqrt(total_face_area))
mask_border[border : h - border, border : w - border, :] = 0
inv_mask_border = cv2.warpAffine(
mask_border, inverse_affine, (w_up, h_up)
)
inv_mask_borders.append(inv_mask_border)
# compute the fusion edge based on the area of face
w_edge = int(total_face_area**0.5) // 20
erosion_radius = w_edge * 2
inv_mask_center = cv2.erode(
inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)
)
blur_size = w_edge * 2
inv_soft_mask = cv2.GaussianBlur(
inv_mask_center, (blur_size + 1, blur_size + 1), 0
)
if len(upsample_img.shape) == 2: # upsample_img is gray image
upsample_img = upsample_img[:, :, None]
inv_soft_mask = inv_soft_mask[:, :, None]
# cv2.imwrite("inv_soft_mask_1.png", (255 * inv_soft_mask).astype(np.uint8))
# parse mask
if self.use_parse:
# inference
face_input = cv2.resize(
restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
face_input = img2tensor(face_input.astype(
'float32') / 255., bgr2rgb=True, float32=True)
normalize(face_input, (0.5, 0.5, 0.5),
(0.5, 0.5, 0.5), inplace=True)
face_input = torch.unsqueeze(face_input, 0).to(self.device)
with torch.no_grad():
out = self.face_parse(face_input)[0]
out = out.argmax(dim=1).squeeze().cpu().numpy()
parse_mask = np.zeros(out.shape)
MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
for idx, color in enumerate(MASK_COLORMAP):
parse_mask[out == idx] = color
# blur the mask
parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
# remove the black borders
thres = 10
parse_mask[:thres, :] = 0
parse_mask[-thres:, :] = 0
parse_mask[:, :thres] = 0
parse_mask[:, -thres:] = 0
parse_mask = parse_mask / 255.0
parse_mask = cv2.resize(parse_mask, face_size)
parse_mask = cv2.warpAffine(
parse_mask, inverse_affine, (w_up, h_up), flags=3
)
inv_soft_parse_mask = parse_mask[:, :, None]
# pasted_face = inv_restored
fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)
# cv2.imwrite("z_inv_soft_mask.png", (255 * inv_soft_mask).astype(np.uint8))
# cv2.imwrite("z_1-inv_soft_mask.png", (255 * (1 - inv_soft_mask)).astype(np.uint8))
# cv2.imwrite("z_upsample_img.png", upsample_img.astype(np.uint8))
# cv2.imwrite("z_pasted_face.png", pasted_face.astype(np.uint8))
# alpha channel
if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:
alpha = upsample_img[:, :, 3:]
upsample_img = inv_soft_mask * pasted_face + \
(1 - inv_soft_mask) * upsample_img[:, :, 0:3]
upsample_img = np.concatenate((upsample_img, alpha), axis=2)
else:
upsample_img = inv_soft_mask * pasted_face + \
(1 - inv_soft_mask) * upsample_img
# cv2.imwrite("z_merged.png", upsample_img.astype(np.uint8))
# import time
# time.sleep(100)
if np.max(upsample_img) > 256: # 16-bit image
upsample_img = upsample_img.astype(np.uint16)
else:
upsample_img = upsample_img.astype(np.uint8)
# draw bounding box
if draw_box:
# upsample_input_img = cv2.resize(input_img, (w_up, h_up))
img_color = np.ones([*upsample_img.shape], dtype=np.float32)
img_color[:, :, 0] = 0
img_color[:, :, 1] = 255
img_color[:, :, 2] = 0
for inv_mask_border in inv_mask_borders:
upsample_img = inv_mask_border * img_color + \
(1 - inv_mask_border) * upsample_img
# upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
if save_path is not None:
path = os.path.splitext(save_path)[0]
save_path = f'{path}.{self.save_ext}'
imwrite(upsample_img, save_path)
return upsample_img
def clean_all(self):
self.all_landmarks_5 = []
self.restored_faces = []
self.affine_matrices = []
self.cropped_faces = []
self.inverse_affine_matrices = []
self.det_faces = []
self.pad_input_imgs = []
class FaceAligner(object):
def __init__(self,
upscale_factor,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
template_3points=False,
pad_blur=False,
use_parse=False,
device=None):
self.template_3points = template_3points # improve robustness
self.upscale_factor = int(upscale_factor)
# the cropped face ratio based on the square face
self.crop_ratio = crop_ratio # (h, w)
assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1]
>= 1), 'crop ration only supports >=1'
self.face_size = (
int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
self.det_model = det_model
if self.det_model == 'dlib':
# standard 5 landmarks for FFHQ faces with 1024 x 1024
self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
[337.91089109, 488.38613861], [
437.95049505, 493.51485149],
[513.58415842, 678.5049505]])
self.face_template = self.face_template / (1024 // face_size)
elif self.template_3points:
self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
else:
# standard 5 landmarks for FFHQ faces with 512 x 512
# facexlib
self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
[201.26117, 371.41043], [313.08905, 371.15118]])
# dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
# self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
# [198.22603, 372.82502], [313.91018, 372.75659]])
self.face_template = self.face_template * (face_size / 512.0)
if self.crop_ratio[0] > 1:
self.face_template[:, 1] += face_size * \
(self.crop_ratio[0] - 1) / 2
if self.crop_ratio[1] > 1:
self.face_template[:, 0] += face_size * \
(self.crop_ratio[1] - 1) / 2
self.save_ext = save_ext
self.pad_blur = pad_blur
if self.pad_blur is True:
self.template_3points = False
self.all_landmarks_5 = []
self.det_faces = []
self.affine_matrices = []
self.inverse_affine_matrices = []
self.cropped_faces = []
self.restored_faces = []
self.pad_input_imgs = []
if device is None:
# self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = get_device()
else:
self.device = device
def set_image(self, img):
self.input_img = img
def align_pair_face(self, img_lq, img_gt, landmarks):
img_lq = (img_lq[:, :, ::-1] * 255).round().astype(np.uint8)
img_gt = (img_gt[:, :, ::-1] * 255).round().astype(np.uint8)
self.set_image(img_gt)
img_lq, img_gt = self.align_warp_face(img_lq, img_gt, landmarks)
img_lq = img_lq[:, :, ::-1] / 255.0
img_gt = img_gt[:, :, ::-1] / 255.0
return img_lq, img_gt
def align_single_face(self, img, landmarks, border_mode='constant'):
"""Align and warp faces with face template.
Suppose input images are Numpy array, (h, w, c), BGR, uint8, [0, 255]
"""
# warp and crop faces
if border_mode == 'constant':
border_mode = cv2.BORDER_CONSTANT
elif border_mode == 'reflect101':
border_mode = cv2.BORDER_REFLECT101
elif border_mode == 'reflect':
border_mode = cv2.BORDER_REFLECT
img = (img[:, :, ::-1] * 255).round().astype(np.uint8)
affine_matrix = cv2.estimateAffinePartial2D(
landmarks, self.face_template, method=cv2.LMEDS)[0]
img = cv2.warpAffine(
img, affine_matrix, img.shape[0:2], borderMode=border_mode, borderValue=(135, 133, 132)) # gray
img = img[:, :, ::-1] / 255.0
return img
def align_warp_face(self, img_lq, img_gt, landmarks, border_mode='constant'):
"""Align and warp faces with face template.
Suppose input images are Numpy array, (h, w, c), BGR, uint8, [0, 255]
"""
# use 5 landmarks to get affine matrix
# use cv2.LMEDS method for the equivalence to skimage transform
# ref: https://blog.csdn.net/yichxi/article/details/115827338
scale = img_gt.shape[0] / img_lq.shape[0]
# warp and crop faces
if border_mode == 'constant':
border_mode = cv2.BORDER_CONSTANT
elif border_mode == 'reflect101':
border_mode = cv2.BORDER_REFLECT101
elif border_mode == 'reflect':
border_mode = cv2.BORDER_REFLECT
affine_matrix = cv2.estimateAffinePartial2D(
landmarks, self.face_template, method=cv2.LMEDS)[0]
img_gt = cv2.warpAffine(
img_gt, affine_matrix, img_gt.shape[0:2], borderMode=border_mode, borderValue=(135, 133, 132)) # gray
affine_matrix = cv2.estimateAffinePartial2D(
landmarks / scale, self.face_template / scale, method=cv2.LMEDS)[0]
img_lq = cv2.warpAffine(
img_lq, affine_matrix, img_lq.shape[0:2], borderMode=border_mode, borderValue=(135, 133, 132)) # gray
return img_lq, img_gt
def clean_all(self):
self.all_landmarks_5 = []
self.restored_faces = []
self.affine_matrices = []
self.cropped_faces = []
self.inverse_affine_matrices = []
self.det_faces = []
self.pad_input_imgs = []
================================================
FILE: facelib/utils/face_utils.py
================================================
import cv2
import numpy as np
import torch
def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
left, top, right, bot = bbox
width = right - left
height = bot - top
if preserve_aspect:
width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
else:
width_increase = height_increase = increase_area
left = int(left - width_increase * width)
top = int(top - height_increase * height)
right = int(right + width_increase * width)
bot = int(bot + height_increase * height)
return (left, top, right, bot)
def get_valid_bboxes(bboxes, h, w):
left = max(bboxes[0], 0)
top = max(bboxes[1], 0)
right = min(bboxes[2], w)
bottom = min(bboxes[3], h)
return (left, top, right, bottom)
def align_crop_face_landmarks(img,
landmarks,
output_size,
transform_size=None,
enable_padding=True,
return_inverse_affine=False,
shrink_ratio=(1, 1)):
"""Align and crop face with landmarks.
The output_size and transform_size are based on width. The height is
adjusted based on shrink_ratio_h/shring_ration_w.
Modified from:
https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
Args:
img (Numpy array): Input image.
landmarks (Numpy array): 5 or 68 or 98 landmarks.
output_size (int): Output face size.
transform_size (ing): Transform size. Usually the four time of
output_size.
enable_padding (float): Default: True.
shrink_ratio (float | tuple[float] | list[float]): Shring the whole
face for height and width (crop larger area). Default: (1, 1).
Returns:
(Numpy array): Cropped face.
"""
lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5
if isinstance(shrink_ratio, (float, int)):
shrink_ratio = (shrink_ratio, shrink_ratio)
if transform_size is None:
transform_size = output_size * 4
# Parse landmarks
lm = np.array(landmarks)
if lm.shape[0] == 5 and lm_type == 'retinaface_5':
eye_left = lm[0]
eye_right = lm[1]
mouth_avg = (lm[3] + lm[4]) * 0.5
elif lm.shape[0] == 5 and lm_type == 'dlib_5':
lm_eye_left = lm[2:4]
lm_eye_right = lm[0:2]
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
mouth_avg = lm[4]
elif lm.shape[0] == 68:
lm_eye_left = lm[36:42]
lm_eye_right = lm[42:48]
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
mouth_avg = (lm[48] + lm[54]) * 0.5
elif lm.shape[0] == 98:
lm_eye_left = lm[60:68]
lm_eye_right = lm[68:76]
eye_left = np.mean(lm_eye_left, axis=0)
eye_right = np.mean(lm_eye_right, axis=0)
mouth_avg = (lm[76] + lm[82]) * 0.5
eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
eye_to_mouth = mouth_avg - eye_avg
# Get the oriented crop rectangle
# x: half width of the oriented crop rectangle
x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
# - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
# norm with the hypotenuse: get the direction
x /= np.hypot(*x) # get the hypotenuse of a right triangle
rect_scale = 1 # TODO: you can edit it to get larger rect
x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
# y: half height of the oriented crop rectangle
y = np.flipud(x) * [-1, 1]
x *= shrink_ratio[1] # width
y *= shrink_ratio[0] # height
# c: center
c = eye_avg + eye_to_mouth * 0.1
# quad: (left_top, left_bottom, right_bottom, right_top)
quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
# qsize: side length of the square
qsize = np.hypot(*x) * 2
quad_ori = np.copy(quad)
# Shrink, for large face
# TODO: do we really need shrink
shrink = int(np.floor(qsize / output_size * 0.5))
if shrink > 1:
h, w = img.shape[0:2]
rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
quad /= shrink
qsize /= shrink
# Crop
h, w = img.shape[0:2]
border = max(int(np.rint(qsize * 0.1)), 3)
crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
int(np.ceil(max(quad[:, 1]))))
crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h))
if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
img = img[crop[1]:crop[3], crop[0]:crop[2], :]
quad -= crop[0:2]
# Pad
# pad: (width_left, height_top, width_right, height_bottom)
h, w = img.shape[0:2]
pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
int(np.ceil(max(quad[:, 1]))))
pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0))
if enable_padding and max(pad) > border - 4:
pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
h, w = img.shape[0:2]
y, x, _ = np.ogrid[:h, :w, :1]
mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
np.float32(w - 1 - x) / pad[2]),
1.0 - np.minimum(np.float32(y) / pad[1],
np.float32(h - 1 - y) / pad[3]))
blur = int(qsize * 0.02)
if blur % 2 == 0:
blur += 1
blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))
img = img.astype('float32')
img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
img = np.clip(img, 0, 255) # float32, [0, 255]
quad += pad[:2]
# Transform use cv2
h_ratio = shrink_ratio[0] / shrink_ratio[1]
dst_h, dst_w = int(transform_size * h_ratio), transform_size
template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
# use cv2.LMEDS method for the equivalence to skimage transform
# ref: https://blog.csdn.net/yichxi/article/details/115827338
affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
cropped_face = cv2.warpAffine(
img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray
if output_size < transform_size:
cropped_face = cv2.resize(
cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR)
if return_inverse_affine:
dst_h, dst_w = int(output_size * h_ratio), output_size
template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
# use cv2.LMEDS method for the equivalence to skimage transform
# ref: https://blog.csdn.net/yichxi/article/details/115827338
affine_matrix = cv2.estimateAffinePartial2D(
quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0]
inverse_affine = cv2.invertAffineTransform(affine_matrix)
else:
inverse_affine = None
return cropped_face, inverse_affine
def paste_face_back(img, face, inverse_affine):
h, w = img.shape[0:2]
face_h, face_w = face.shape[0:2]
inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
mask = np.ones((face_h, face_w, 3), dtype=np.float32)
inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
# remove the black borders
inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
inv_restored_remove_border = inv_mask_erosion * inv_restored
total_face_area = np.sum(inv_mask_erosion) // 3
# compute the fusion edge based on the area of face
w_edge = int(total_face_area**0.5) // 20
erosion_radius = w_edge * 2
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
blur_size = w_edge * 2
inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
# float32, [0, 255]
return img
if __name__ == '__main__':
import os
from facelib.detection import init_detection_model
from facelib.utils.face_restoration_helper import get_largest_face
img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png'
img_name = os.splitext(os.path.basename(img_path))[0]
# initialize model
det_net = init_detection_model('retinaface_resnet50', half=False)
img_ori = cv2.imread(img_path)
h, w = img_ori.shape[0:2]
# if larger than 800, scale it
scale = max(h / 800, w / 800)
if scale > 1:
img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR)
with torch.no_grad():
bboxes = det_net.detect_faces(img, 0.97)
if scale > 1:
bboxes *= scale # the score is incorrect
bboxes = get_largest_face(bboxes, h, w)[0]
landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)])
cropped_face, inverse_affine = align_crop_face_landmarks(
img_ori,
landmarks,
output_size=512,
transform_size=None,
enable_padding=True,
return_inverse_affine=True,
shrink_ratio=(1, 1))
cv2.imwrite(f'tmp/{img_name}_cropeed_face.png', cropped_face)
img = paste_face_back(img_ori, cropped_face, inverse_affine)
cv2.imwrite(f'tmp/{img_name}_back.png', img)
================================================
FILE: facelib/utils/misc.py
================================================
import cv2
import os
import os.path as osp
import numpy as np
from PIL import Image
import torch
from torch.hub import download_url_to_file, get_dir
from urllib.parse import urlparse
# from basicsr.utils.download_util import download_file_from_google_drive
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def download_pretrained_models(file_ids, save_path_root):
import gdown
os.makedirs(save_path_root, exist_ok=True)
for file_name, file_id in file_ids.items():
file_url = 'https://drive.google.com/uc?id='+file_id
save_path = osp.abspath(osp.join(save_path_root, file_name))
if osp.exists(save_path):
user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
if user_response.lower() == 'y':
print(f'Covering {file_name} to {save_path}')
gdown.download(file_url, save_path, quiet=False)
# download_file_from_google_drive(file_id, save_path)
elif user_response.lower() == 'n':
print(f'Skipping {file_name}')
else:
raise ValueError('Wrong input. Only accepts Y/N.')
else:
print(f'Downloading {file_name} to {save_path}')
gdown.download(file_url, save_path, quiet=False)
# download_file_from_google_drive(file_id, save_path)
def imwrite(img, file_path, params=None, auto_mkdir=True):
"""Write image to file.
Args:
img (ndarray): Image array to be written.
file_path (str): Image file path.
params (None or list): Same as opencv's :func:`imwrite` interface.
auto_mkdir (bool): If the parent folder of `file_path` does not exist,
whether to create it automatically.
Returns:
bool: Successful or not.
"""
if auto_mkdir:
dir_name = os.path.abspath(os.path.dirname(file_path))
os.makedirs(dir_name, exist_ok=True)
return cv2.imwrite(file_path, img, params)
def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""
def _totensor(img, bgr2rgb, float32):
if img.shape[2] == 3 and bgr2rgb:
if img.dtype == 'float64':
img = img.astype('float32')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = torch.from_numpy(img.transpose(2, 0, 1))
if float32:
img = img.float()
return img
if isinstance(imgs, list):
return [_totensor(img, bgr2rgb, float32) for img in imgs]
else:
return _totensor(imgs, bgr2rgb, float32)
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
"""
if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
"""Scan a directory to find the interested files.
Args:
dir_path (str): Path of the directory.
suffix (str | tuple(str), optional): File suffix that we are
interested in. Default: None.
recursive (bool, optional): If set to True, recursively scan the
directory. Default: False.
full_path (bool, optional): If set to True, include the dir_path.
Default: False.
Returns:
A generator for all the interested files with relative paths.
"""
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
raise TypeError('"suffix" must be a string or tuple of strings')
root = dir_path
def _scandir(dir_path, suffix, recursive):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
if full_path:
return_path = entry.path
else:
return_path = osp.relpath(entry.path, root)
if suffix is None:
yield return_path
elif return_path.endswith(suffix):
yield return_path
else:
if recursive:
yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
else:
continue
return _scandir(dir_path, suffix=suffix, recursive=recursive)
def is_gray(img, threshold=10):
img = Image.fromarray(img)
if len(img.getbands()) == 1:
return True
img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16)
img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16)
img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16)
diff1 = (img1 - img2).var()
diff2 = (img2 - img3).var()
diff3 = (img3 - img1).var()
diff_sum = (diff1 + diff2 + diff3) / 3.0
if diff_sum <= threshold:
return True
else:
return False
def rgb2gray(img, out_channel=3):
r, g, b = img[:,:,0], img[:,:,1], img[:,:,2]
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
if out_channel == 3:
gray = gray[:,:,np.newaxis].repeat(3, axis=2)
return gray
def bgr2gray(img, out_channel=3):
b, g, r = img[:,:,0], img[:,:,1], img[:,:,2]
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
if out_channel == 3:
gray = gray[:,:,np.newaxis].repeat(3, axis=2)
return gray
def calc_mean_std(feat, eps=1e-5):
"""
Args:
feat (numpy): 3D [w h c]s
"""
size = feat.shape
assert len(size) == 3, 'The input feature should be 3D tensor.'
c = size[2]
feat_var = feat.reshape(-1, c).var(axis=0) + eps
feat_std = np.sqrt(feat_var).reshape(1, 1, c)
feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c)
return feat_mean, feat_std
def adain_npy(content_feat, style_feat):
"""Adaptive instance normalization for numpy.
Args:
content_feat (numpy): The input feature.
style_feat (numpy): The reference feature.
"""
size = content_feat.shape
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size)
return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size)
================================================
FILE: options/clip5_bs2_512_align_nofix_multiscale.yaml
================================================
# general settings
name: BFR_test
model_type: CodeFormerDirichletVideoModel
num_gpu: 1
manual_seed: 0
# dataset and data loader settings
datasets:
train:
name: VFHQ-Train
type: VFHQRealDegradationDatasetNew
dataroot_gt: # replace your training data root path
global_meta_info_file: # replace with your training data meta info
dataroot_meta_info: # replace with the landmarks info of your training data
io_backend:
type: disk
video_length: 5
scale: 4
need_align: True # make sure that dataroot_meta_info is the landmarks of your data
normalize: True
interval_list: [1]
random_reverse: True
use_flip: False
use_rot: False
blur_kernel_size: 21
kernel_list: ['iso', 'aniso']
kernel_prob: [0.7, 0.3]
blur_x_sigma: [0.1, 10]
blur_y_sigma: [0.1, 10]
noise_range: [0, 10]
resize_prob: [0.20, 0.40, 0.40]
crf_range: [18, 25]
vcodec: ['libx264']
vcodec_prob: [1]
# data loader
num_worker_per_gpu: 4
batch_size_per_gpu: 2
dataset_enlarge_ratio: 20
prefetch_mode: ~
val:
name: VFHQ-Test-50
type: VFHQRealDegradationDatasetNew
dataroot_gt: # replace with your test data root path
global_meta_info_file: # test data meta
dataroot_meta_info: # landmark info of your test data
io_backend:
type: disk
video_length: 5
scale: 4
need_align: True
normalize: True
interval_list: [1]
random_reverse: False
use_flip: False
use_rot: False
blur_kernel_size: 21
kernel_list: ['iso', 'aniso']
kernel_prob: [0.7, 0.3]
blur_x_sigma: [0.1, 10]
blur_y_sigma: [0.1, 10]
noise_range: [0, 10]
resize_prob: [0.20, 0.40, 0.40]
crf_range: [18, 25]
vcodec: ['libx264']
vcodec_prob: [1]
# data loader
num_worker_per_gpu: 2
batch_size_per_gpu: 1
dataset_enlarge_ratio: 1
prefetch_mode: ~
# network structures
network_g:
type: TemporalCodeFormerDirDistMultiScale
dim_embed: 512
n_head: 8
n_layers: 9
codebook_size: 1024
connect_list: ['32', '64', '128', '256']
# fix_modules: ['encoder','quantize', 'fuse_convs_dict', 'feat_emb']
fix_modules: [] # you can fix some module
frame_length: 5
network_d:
type: VQGANDiscriminator
nc: 3
ndf: 64
n_layers: 4
# path
path:
pretrain_network_g: './ckpts/CodeFormer/codeformer.pth'
param_key_g: params_ema
strict_load_g: false
pretrain_network_d: './ckpts/CodeFormer/vqgan_discriminator.pth'
strict_load_d: true
resume_state: ~
# base_lr(4.5e-6)*bach_size(4)
train:
cross_entropy_loss: true
entropy_loss_weight: 0.5
fidelity_weight: 0
optim_g:
type: Adam
lr: !!float 5e-5
weight_decay: 0
betas: [0.9, 0.99]
optim_d:
type: Adam
lr: !!float 5e-5
weight_decay: 0
betas: [0.9, 0.99]
# scheduler:
# type: MultiStepLR
# milestones: [30000, 45000]
# gamma: 0.5
scheduler:
type: CosineAnnealingRestartLR
periods: [100000]
restart_weights: [1]
eta_min: !!float 2e-5
total_iter: 100000
warmup_iter: -1 # no warm up
ema_decay: 0.997
# training loss
pixel_opt:
type: L1Loss
loss_weight: 1.0
reduction: mean
perceptual_opt:
type: LPIPSLoss
loss_weight: 1.0
use_input_norm: true
range_norm: true
dirichletKL_opt:
type: DirichletKLLoss
loss_weight: 1.00
kl_coef: 0.1
gan_opt:
type: GANLoss
gan_type: hinge
loss_weight: !!float 1.0 # adaptive_weighting
use_adaptive_weight: true
net_g_start_iter: 0
net_d_iters: 1
net_d_start_iter: 6000000000
manual_seed: 0
# validation settings
val:
val_freq: !!float 10 # no validation
save_img: true
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 4
test_y_channel: false
# logging settings
logger:
print_freq: 1 # Frequency (iterations) to print training logs to console
save_checkpoint_freq: !!float 10 # Frequency (iterations) to save model checkpoints
use_tb_logger: true # Enable TensorBoard logging
wandb:
mode: offline # Logging mode: 'offline' (local only) or 'online' (sync to Weights & Biases)
# Set to 'online' to upload training metrics to Weights & Biases
project: project_name # WandB project name
resume_id: ~ # ID to resume a previous WandB run (leave as ~ for new runs)
# dist training settings
dist_params:
backend: nccl
port: 29412
find_unused_parameters: false
================================================
FILE: options/clip5_bs2_512_align_nofix_multiscale_color.yaml
================================================
# general settings
name: codeformer_dirichlet_clip5_bs2_align_nofix_multiscale_color
model_type: CodeFormerDirichletVideoModel
num_gpu: 8
manual_seed: 0
# dataset and data loader settings
datasets:
train:
name: VFHQ-Train
type: ColorizationDataset
dataroot_gt: /sykj_002/datasets/VFHQ/VFHQ_DATAset/VFHQ_DATA_512x512
global_meta_info_file: # path to your training data meta file
dataroot_meta_info: #
io_backend:
type: disk
video_length: 5
scale: 4
need_align: True
normalize: True
interval_list: [1]
random_reverse: True
use_flip: False
use_rot: False
# large degradation in stageII
# blur_kernel_size: 41
blur_kernel_size: 21
kernel_list: ['iso', 'aniso'] # 模糊核的类型列表
# kernel_prob: [0.5, 0.5] # 模糊核类型的概率
kernel_prob: [0.7, 0.3]
# blur_x_sigma: [0.2, 3] # 模糊核在 x 方向的标准差范围
blur_x_sigma: [0.1, 10]
# blur_y_sigma: [0.2, 3] # 模糊核在 y 方向的标准差范围
blur_y_sigma: [0.1, 10]
# noise_range: [0, 25] # 噪声范围
noise_range: [0, 10]
resize_prob: [0.20, 0.40, 0.40] # 不同插值方法的概率
# use_crf: True # 是否使用crf压缩
# crf_range: [10, 30] # CRF 压缩范围
crf_range: [18, 25]
vcodec: ['libx264'] # 视频编码格式
vcodec_prob: [1] # 视频编码格式的概率
latent_gt_path: ~ # without pre-calculated latent code
# latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
# data loader
num_worker_per_gpu: 4
batch_size_per_gpu: 2
dataset_enlarge_ratio: 10
prefetch_mode: ~
val:
name: VFHQ-Test-50
type: ColorizationDataset
# dataroot_gt: ../VFHQ_Test/VAL_cases
# global_meta_info_file: ./vfhq_val_data_info.txt
# dataroot_meta_info: ./vfhq_val_landmarks
dataroot_gt: /sykj_002/datasets/VFHQ/VFHQ_DATAset/VFHQ_Test/TEST_DATA
global_meta_info_file: ./vfhq_test.txt
dataroot_meta_info: /sykj_002/datasets/VFHQ/VFHQ_DATAset/VFHQ_Test/vfhq_test_landmarks
io_backend:
type: disk
video_length: 5
scale: 4
need_align: True
normalize: True
interval_list: [1]
random_reverse: False
use_flip: False
use_rot: False
# large degradation in stageII
blur_kernel_size: 21
kernel_list: ['iso', 'aniso'] # 模糊核的类型列表
# kernel_prob: [0.5, 0.5] # 模糊核类型的概率
kernel_prob: [0.7, 0.3]
# blur_x_sigma: [0.2, 3] # 模糊核在 x 方向的标准差范围
blur_x_sigma: [0.1, 10]
# blur_y_sigma: [0.2, 3] # 模糊核在 y 方向的标准差范围
blur_y_sigma: [0.1, 10]
# noise_range: [0, 25] # 噪声范围
noise_range: [0, 10]
resize_prob: [0.20, 0.40, 0.40] # 不同插值方法的概率
# use_crf: True # 是否使用crf压缩
# crf_range: [10, 30] # CRF 压缩范围
crf_range: [18, 25]
vcodec: ['libx264'] # 视频编码格式
vcodec_prob: [1] # 视频编码格式的概率
# data loader
num_worker_per_gpu: 4
batch_size_per_gpu: 2
dataset_enlarge_ratio: 1
prefetch_mode: ~
# network structures
network_g:
type: TemporalCodeFormerDirDistMultiScale
dim_embed: 512
n_head: 8
n_layers: 9
codebook_size: 1024
connect_list: ['32', '64', '128', '256']
# fix_modules: ['encoder','quantize', 'fuse_convs_dict', 'feat_emb'] # decoder 放开, generator
fix_modules: []
# vqgan_path: './weights/CodeFormer/vqgan_code1024.pth' # pretrained VQGAN
frame_length: 5
# network_vqgan: # this config is needed if no pre-calculated latent
# type: VQAutoEncoder
# img_size: 512
# nf: 64
# ch_mult: [1, 2, 2, 4, 4, 8]
# quantizer: 'nearest'
# codebook_size: 1024
network_d:
type: VQGANDiscriminator
nc: 3
ndf: 64
n_layers: 4
# path
path:
pretrain_network_g: './weights/CodeFormer/codeformer.pth'
param_key_g: params_ema
strict_load_g: false
pretrain_network_d: './weights/CodeFormer/vqgan_discriminator.pth'
strict_load_d: true
resume_state: ~
# base_lr(4.5e-6)*bach_size(4)
train:
# use_hq_feat_loss: False
# feat_loss_weight: 1.0
cross_entropy_loss: true
entropy_loss_weight: 0.5
fidelity_weight: 0
optim_g:
type: Adam
lr: !!float 5e-5
weight_decay: 0
betas: [0.9, 0.99]
optim_d:
type: Adam
lr: !!float 5e-5
weight_decay: 0
betas: [0.9, 0.99]
# scheduler:
# type: MultiStepLR
# milestones: [30000, 45000]
# gamma: 0.5
scheduler:
type: CosineAnnealingRestartLR
periods: [100000]
restart_weights: [1]
eta_min: !!float 2e-5
total_iter: 100000
warmup_iter: -1 # no warm up
ema_decay: 0.997
# training loss
pixel_opt:
type: L1Loss
loss_weight: 1.0
reduction: mean
perceptual_opt:
type: LPIPSLoss
loss_weight: 1.0
use_input_norm: true
range_norm: true
dirichletKL_opt:
type: DirichletKLLoss
loss_weight: 0.00
kl_coef: 0.1
gan_opt:
type: GANLoss
gan_type: hinge
loss_weight: !!float 1.0 # adaptive_weighting
use_adaptive_weight: true
net_g_start_iter: 0
net_d_iters: 1
net_d_start_iter: 6000000000
manual_seed: 0
# validation settings
val:
val_freq: !!float 1000 # no validation
save_img: true
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 4
test_y_channel: false
# logging settings
logger:
print_freq: 100
save_checkpoint_freq: !!float 1000
use_tb_logger: true
wandb:
mode: offline
project: codeformer_dirichlet_clip5_color_dloss
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29412
find_unused_parameters: false
================================================
FILE: options/clip5_bs2_512_align_nofix_multiscale_inpaint.yaml
================================================
# general settings
name: codeformer_dirichlet_clip5_bs2_align_nofix_multiscale_inpaint
model_type: CodeFormerDirichletVideoModel
num_gpu: 8
manual_seed: 0
# dataset and data loader settings
datasets:
train:
name: VFHQ-Train
type: InpaintingDataset
dataroot_gt: /sykj_002/datasets/VFHQ/VFHQ_DATAset/VFHQ_DATA_512x512
global_meta_info_file: ./vfhq_train_data.txt
dataroot_meta_info: /sykj_002/datasets/VFHQ/VFHQ_DATAset/vfhq_train_landmarks
io_backend:
type: disk
video_length: 5
scale: 4
need_align: True
normalize: True
interval_list: [1]
random_reverse: True
use_flip: False
use_rot: False
# large degradation in stageII
# blur_kernel_size: 41
blur_kernel_size: 21
kernel_list: ['iso', 'aniso'] # 模糊核的类型列表
# kernel_prob: [0.5, 0.5] # 模糊核类型的概率
kernel_prob: [0.7, 0.3]
# blur_x_sigma: [0.2, 3] # 模糊核在 x 方向的标准差范围
blur_x_sigma: [0.1, 10]
# blur_y_sigma: [0.2, 3] # 模糊核在 y 方向的标准差范围
blur_y_sigma: [0.1, 10]
# noise_range: [0, 25] # 噪声范围
noise_range: [0, 10]
resize_prob: [0.20, 0.40, 0.40] # 不同插值方法的概率
# use_crf: True # 是否使用crf压缩
# crf_range: [10, 30] # CRF 压缩范围
crf_range: [18, 25]
vcodec: ['libx264'] # 视频编码格式
vcodec_prob: [1] # 视频编码格式的概率
latent_gt_path: ~ # without pre-calculated latent code
# latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'
# data loader
num_worker_per_gpu: 4
batch_size_per_gpu: 2
dataset_enlarge_ratio: 10
prefetch_mode: ~
val:
name: VFHQ-Test-50
type: InpaintingDataset
# dataroot_gt: ../VFHQ_Test/VAL_cases
# global_meta_info_file: ./vfhq_val_data_info.txt
# dataroot_meta_info: ./vfhq_val_landmarks
dataroot_gt: /sykj_002/datasets/VFHQ/VFHQ_DATAset/VFHQ_Test/TEST_DATA
global_meta_info_file: ./vfhq_test.txt
dataroot_meta_info: /sykj_002/datasets/VFHQ/VFHQ_DATAset/VFHQ_Test/vfhq_test_landmarks
io_backend:
type: disk
video_length: 5
scale: 4
need_align: True
normalize: True
interval_list: [1]
random_reverse: False
use_flip: False
use_rot: False
# large degradation in stageII
blur_kernel_size: 21
kernel_list: ['iso', 'aniso'] # 模糊核的类型列表
# kernel_prob: [0.5, 0.5] # 模糊核类型的概率
kernel_prob: [0.7, 0.3]
# blur_x_sigma: [0.2, 3] # 模糊核在 x 方向的标准差范围
blur_x_sigma: [0.1, 10]
# blur_y_sigma: [0.2, 3] # 模糊核在 y 方向的标准差范围
blur_y_sigma: [0.1, 10]
# noise_range: [0, 25] # 噪声范围
noise_range: [0, 10]
resize_prob: [0.20, 0.40, 0.40] # 不同插值方法的概率
# use_crf: True # 是否使用crf压缩
# crf_range: [10, 30] # CRF 压缩范围
crf_range: [18, 25]
vcodec: ['libx264'] # 视频编码格式
vcodec_prob: [1] # 视频编码格式的概率
# data loader
num_worker_per_gpu: 4
batch_size_per_gpu: 2
dataset_enlarge_ratio: 1
prefetch_mode: ~
# network structures
network_g:
type: TemporalCodeFormerDirDistMultiScale
dim_embed: 512
n_head: 8
n_layers: 9
codebook_size: 1024
connect_list: ['32', '64', '128', '256']
# fix_modules: ['encoder','quantize', 'fuse_convs_dict', 'feat_emb'] # decoder 放开, generator
fix_modules: []
# vqgan_path: './weights/CodeFormer/vqgan_code1024.pth' # pretrained VQGAN
frame_length: 5
# network_vqgan: # this config is needed if no pre-calculated latent
# type: VQAutoEncoder
# img_size: 512
# nf: 64
# ch_mult: [1, 2, 2, 4, 4, 8]
# quantizer: 'nearest'
# codebook_size: 1024
network_d:
type: VQGANDiscriminator
nc: 3
ndf: 64
n_layers: 4
# path
path:
pretrain_network_g: './weights/CodeFormer/codeformer.pth'
param_key_g: params_ema
strict_load_g: false
pretrain_network_d: './weights/CodeFormer/vqgan_discriminator.pth'
strict_load_d: true
resume_state: ~
# base_lr(4.5e-6)*bach_size(4)
train:
# use_hq_feat_loss: False
# feat_loss_weight: 1.0
cross_entropy_loss: true
entropy_loss_weight: 0.5
fidelity_weight: 0
optim_g:
type: Adam
lr: !!float 5e-5
weight_decay: 0
betas: [0.9, 0.99]
optim_d:
type: Adam
lr: !!float 5e-5
weight_decay: 0
betas: [0.9, 0.99]
# scheduler:
# type: MultiStepLR
# milestones: [30000, 45000]
# gamma: 0.5
scheduler:
type: CosineAnnealingRestartLR
periods: [100000]
restart_weights: [1]
eta_min: !!float 2e-5
total_iter: 100000
warmup_iter: -1 # no warm up
ema_decay: 0.997
# training loss
pixel_opt:
type: L1Loss
loss_weight: 1.0
reduction: mean
perceptual_opt:
type: LPIPSLoss
loss_weight: 1.0
use_input_norm: true
range_norm: true
dirichletKL_opt:
type: DirichletKLLoss
loss_weight: 0.00
kl_coef: 0.1
gan_opt:
type: GANLoss
gan_type: hinge
loss_weight: !!float 1.0 # adaptive_weighting
use_adaptive_weight: true
net_g_start_iter: 0
net_d_iters: 1
net_d_start_iter: 6000000000
manual_seed: 0
# validation settings
val:
val_freq: !!float 1000 # no validation
save_img: true
metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 4
test_y_channel: false
# logging settings
logger:
print_freq: 100
save_checkpoint_freq: !!float 1000
use_tb_logger: true
wandb:
mode: offline
project: codeformer_dirichlet_clip5_bs2_align_nofix_multiscale_inpaint
resume_id: ~
# dist training settings
dist_params:
backend: nccl
port: 29412
find_unused_parameters: false
================================================
FILE: requirements.txt
================================================
addict
future
lmdb
numpy
opencv-python
Pillow
pyyaml
requests
scikit-image
scipy
# tb-nightly
tensorboard
torch>=1.7.1
torchvision
tqdm
yapf
lpips
einops
av
ffmpeg-python
wandb
================================================
FILE: scripts/inference.py
================================================
import os
import cv2
import argparse
import glob
import torch
import numpy as np
from torchvision.transforms.functional import normalize
from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.misc import gpu_is_available, get_device
from scipy.ndimage import gaussian_filter1d
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.utils.misc import is_gray
from basicsr.utils.video_util import VideoReader, VideoWriter
from einops import rearrange
from basicsr.utils.registry import ARCH_REGISTRY
def interpolate_sequence(sequence):
interpolated_sequence = np.copy(sequence)
missing_indices = np.isnan(sequence)
if np.any(missing_indices):
valid_indices = ~missing_indices
x = np.arange(len(sequence))
# Interpolate missing values using valid data points
interpolated_sequence[missing_indices] = np.interp(
x[missing_indices], x[valid_indices], sequence[valid_indices]
)
return interpolated_sequence
def set_realesrgan():
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.realesrgan_utils import RealESRGANer
use_half = False
if torch.cuda.is_available(): # set False in CPU/MPS mode
# set False for GPUs that don't support f16
no_half_gpu_list = ["1650", "1660"]
if not True in [
gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list
]:
use_half = True
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
upsampler = RealESRGANer(
scale=2,
model_path="./ckpts/realesrgan/RealESRGAN_x2plus.pth",
model=model,
tile=args.bg_tile,
tile_pad=40,
pre_pad=0,
half=use_half,
)
if not gpu_is_available(): # CPU
import warnings
warnings.warn(
"Running on CPU now! Make sure your PyTorch version matches your CUDA."
"The unoptimized RealESRGAN is slow on CPU. "
"If you want to disable it, please remove `--bg_upsampler` and `--face_upsample` in command.",
category=RuntimeWarning,
)
return upsampler
if __name__ == "__main__":
device = get_device()
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input_path",
type=str,
default="None",
help="Input image, video or folder. Default: inputs/whole_imgs",
)
parser.add_argument(
"-o",
"--output_path",
type=str,
default="results/",
help="Output folder. Default: results/",
)
parser.add_argument(
"--save_video", action="store_true", help="Save output as video. Default: False"
)
parser.add_argument(
"-s",
"--upscale",
type=int,
default=2,
help="The final upsampling scale of the image. Default: 1",
)
parser.add_argument(
"--max_length",
type=int,
default=20,
help="Max length of per sub-clip depending of GPU memory. Default: 20",
)
parser.add_argument(
"--has_aligned",
action="store_true",
help="Input are cropped and aligned faces. Default: False",
)
parser.add_argument(
"--only_center_face",
type=bool,
default=True,
help="Only restore the center face. Default: True",
)
parser.add_argument(
"--draw_box",
action="store_true",
help="Draw the bounding box for the detected faces. Default: False",
)
parser.add_argument(
"--detection_model",
type=str,
default="retinaface_resnet50",
help="Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \
Default: retinaface_resnet50",
)
parser.add_argument(
"--bg_upsampler",
type=str,
default="None",
help="Background upsampler. Optional: realesrgan",
)
parser.add_argument(
"--face_upsample",
action="store_true",
help="Face upsampler after enhancement. Default: False",
)
parser.add_argument(
"--bg_tile",
type=int,
default=400,
help="Tile size for background sampler. Default: 400",
)
parser.add_argument(
"--save_video_fps",
type=float,
default=20,
help="Frame rate for saving video. Default: 20",
)
parser.add_argument(
"--ckpt_path", type=str, default="None", help="the loaded ckpt file path"
)
args = parser.parse_args()
input_video = False
ckpt_path = args.ckpt_path
weight_parameter = 1.0
# ------------------ set up background upsampler ------------------
print("------------------ set up background upsampler ------------------")
if args.bg_upsampler == "realesrgan":
bg_upsampler = set_realesrgan()
else:
bg_upsampler = None
# ------------------ set up face upsampler ------------------
if args.face_upsample:
if bg_upsampler is not None:
face_upsampler = bg_upsampler
else:
face_upsampler = set_realesrgan()
else:
face_upsampler = None
os.makedirs(args.output_path, exist_ok=True)
# ------------------ set up restorer -------------------
net = ARCH_REGISTRY.get("TemporalCodeFormerDirDistMultiScale")(
dim_embed=512,
n_head=8,
n_layers=9,
codebook_size=1024,
connect_list=["32", "64", "128", "256"],
frame_length=5,
).to(device)
checkpoint = torch.load(ckpt_path)["params_ema"]
net.load_state_dict(checkpoint)
net.eval()
# ------------------ set up FaceRestoreHelper -------------------
# large det_model: 'YOLOv5l', 'retinaface_resnet50'
# small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
if not args.has_aligned:
print(f"Face detection model: {args.detection_model}")
if bg_upsampler is not None:
print(f"Background upsampling: True. Face upsampling: {args.face_upsample}")
else:
print(f"Background upsampling: False. Face upsampling: {args.face_upsample}")
face_helper = FaceRestoreHelper(
args.upscale,
face_size=512,
crop_ratio=(1, 1),
det_model=args.detection_model,
save_ext="png",
use_parse=True,
device=device,
)
# -------------------- start processing ---------------------
input_img_list = []
restored_img_list = []
if args.input_path.endswith(
("mp4", "mov", "avi", "MP4", "MOV", "AVI")
): # input video path
vidreader = VideoReader(args.input_path)
image = vidreader.get_frame()
while image is not None:
input_img_list.append(image)
image = vidreader.get_frame()
fps = (
vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps
)
vidreader.close()
clip_name = os.path.basename(args.input_path)[:-4]
result_root = os.path.join(args.output_path, clip_name)
os.makedirs(result_root, exist_ok=True)
elif os.path.isdir(args.input_path): # input img folder
# scan all the jpg and png images
for img_path in sorted(
glob.glob(os.path.join(args.input_path, "*.[jpJP][pnPN]*[gG]"))
):
input_img_list.append(cv2.imread(img_path))
clip_name = os.path.basename(args.input_path)
result_root = os.path.join(args.output_path, clip_name)
os.makedirs(result_root, exist_ok=True)
else:
raise TypeError(f"Unrecognized type of input video {args.input_path}.")
if len(input_img_list) == 0:
raise FileNotFoundError(
"No input image/video is found...\n"
"\tNote that --input_path for video should end with .mp4|.mov|.avi"
)
if not args.has_aligned:
# Smoothing aligned landmarks
print("Detecting keypoints and smooth alignment ...")
raw_landmarks = []
for i, img in enumerate(input_img_list):
# clean all the intermediate results to process the next image
face_helper.clean_all()
face_helper.read_image(img)
# get face landmarks for each face
num_det_faces = face_helper.get_face_landmarks_5(
only_center_face=args.only_center_face,
resize=640,
eye_dist_threshold=5,
only_keep_largest=True,
)
if num_det_faces == 1:
raw_landmarks.append(face_helper.all_landmarks_5[0].reshape((10,)))
elif num_det_faces == 0:
raw_landmarks.append(np.array([np.nan] * 10))
raw_landmarks = np.array(raw_landmarks)
for i in range(10):
raw_landmarks[:, i] = interpolate_sequence(raw_landmarks[:, i])
video_length = len(input_img_list)
avg_landmarks = gaussian_filter1d(raw_landmarks, 5, axis=0).reshape(
video_length, 5, 2
)
# Pack cropped faces.
cropped_faces = []
for i, img in enumerate(input_img_list):
if not args.has_aligned:
face_helper.clean_all()
face_helper.read_image(img)
face_helper.all_landmarks_5 = [avg_landmarks[i]]
face_helper.align_warp_face()
else:
img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
face_helper.is_gray = is_gray(img, threshold=10)
if face_helper.is_gray:
print("Grayscale input: True")
face_helper.cropped_faces = [img]
cropped_face_t = img2tensor(
face_helper.cropped_faces[0] / 255.0, bgr2rgb=True, float32=True
)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_faces.append(cropped_face_t)
cropped_faces = torch.stack(cropped_faces, dim=0).unsqueeze(0).to(device)
print("Restoring faces ...")
with torch.no_grad():
video_length = cropped_faces.shape[1]
output = []
for start_idx in range(0, video_length):
pre_length = args.max_length // 2
post_length = args.max_length - pre_length - 1
padding_begin_idx = start_idx - pre_length
padding_end_idx = start_idx + post_length
if padding_begin_idx < 0:
pre_padding = torch.zeros(
(
cropped_faces.shape[0],
-padding_begin_idx,
*cropped_faces.shape[2:],
),
dtype=cropped_faces.dtype,
device=cropped_faces.device,
)
pre_padding = pre_padding + cropped_faces[:, 0:1]
small_clip = torch.cat(
[pre_padding, cropped_faces[:, : padding_end_idx + 1, ...]], dim=1
)
elif padding_end_idx >= video_length:
post_padding = torch.zeros(
(
cropped_faces.shape[0],
padding_end_idx - video_length + 1,
*cropped_faces.shape[2:],
),
dtype=cropped_faces.dtype,
device=cropped_faces.device,
)
post_padding = post_padding + cropped_faces[:, -1:]
small_clip = torch.cat(
[cropped_faces[:, padding_begin_idx:, ...], post_padding], dim=1
)
else:
small_clip = cropped_faces[
:, padding_begin_idx : padding_end_idx + 1, ...
]
small_clip = rearrange(
small_clip, "b t c h w -> (b t) c h w", t=args.max_length
)
bt = small_clip.shape[0]
res, _, _ = net(
small_clip, w=weight_parameter
)
res = rearrange(res, "(b t) c h w -> b t c h w", t=args.max_length)
res = res[:, pre_length : pre_length + 1, ...]
output.append(res)
output = torch.cat(output, dim=1).squeeze(0)
assert output.shape[0] == video_length, "Differer number of frames"
restored_faces = [tensor2img(x, rgb2bgr=True, min_max=(-1, 1)) for x in output]
del output
torch.cuda.empty_cache()
print("Pasting faces back ...")
for i, img in enumerate(input_img_list):
# clean all the intermediate results to process the next image
face_helper.clean_all()
if args.has_aligned:
# the input faces are already cropped and aligned
img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
face_helper.is_gray = is_gray(img, threshold=10)
if face_helper.is_gray:
print("Grayscale input: True")
face_helper.cropped_faces = [img]
else:
# align and warp each face
face_helper.read_image(img)
face_helper.all_landmarks_5 = [avg_landmarks[i]]
face_helper.align_warp_face()
face_helper.add_restored_face(restored_faces[i].astype("uint8"))
# paste_back
if not args.has_aligned:
# upsample the background
if bg_upsampler is not None:
# Now only support RealESRGAN for upsampling background
bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0]
else:
bg_img = None
face_helper.get_inverse_affine(None)
# paste each restored face to the input image
if args.face_upsample and face_upsampler is not None:
restored_img = face_helper.paste_faces_to_input_image(
upsample_img=bg_img,
draw_box=args.draw_box,
face_upsampler=face_upsampler,
)
else:
restored_img = face_helper.paste_faces_to_input_image(
upsample_img=bg_img, draw_box=args.draw_box
)
restored_img_list.append(restored_img)
# save faces
save_face_name = f"{i:08d}.png"
for face_idx, (cropped_face, restored_face) in enumerate(
zip(face_helper.cropped_faces, face_helper.restored_faces)
):
# save cropped face
if not args.has_aligned:
save_crop_path = os.path.join(
result_root, "cropped_faces", save_face_name
)
imwrite(cropped_face, save_crop_path)
# save restored face
save_restore_path = os.path.join(
result_root, "restored_faces", save_face_name
)
imwrite(restored_face, save_restore_path)
# save restored img
if not args.has_aligned and restored_img is not None:
save_restore_path = os.path.join(
result_root, "final_results", save_face_name
)
imwrite(restored_img, save_restore_path)
# save enhanced video
if args.save_video:
print("Saving video ...")
# load images
video_frames = []
if not args.has_aligned:
img_list = sorted(
glob.glob(os.path.join(result_root, "final_results", "*.[jp][pn]g"))
)
else:
img_list = sorted(
glob.glob(os.path.join(result_root, "restored_faces", "*.[jp][pn]g"))
)
for img_path in img_list:
img = cv2.imread(img_path)
video_frames.append(img)
height, width = video_frames[0].shape[:2]
save_restore_path = os.path.join(args.output_path, f"{clip_name}.mp4")
vidwriter = VideoWriter(
save_restore_path, height, width, args.save_video_fps, audio=None
)
for f in video_frames:
vidwriter.write_frame(f)
vidwriter.close()
print(f"\nAll results are saved in {result_root}")
================================================
FILE: scripts/inference_color_and_inpainting.py
================================================
import os
import cv2
import argparse
import glob
import torch
import numpy as np
from torchvision.transforms.functional import normalize
from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.misc import gpu_is_available, get_device
from scipy.ndimage import gaussian_filter1d
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.utils.misc import is_gray
from basicsr.utils.video_util import VideoReader, VideoWriter
from einops import rearrange
from basicsr.utils.registry import ARCH_REGISTRY
def interpolate_sequence(sequence):
interpolated_sequence = np.copy(sequence)
missing_indices = np.isnan(sequence)
if np.any(missing_indices):
valid_indices = ~missing_indices
x = np.arange(len(sequence))
# Interpolate missing values using valid data points
interpolated_sequence[missing_indices] = np.interp(
x[missing_indices], x[valid_indices], sequence[valid_indices]
)
return interpolated_sequence
def set_realesrgan():
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.realesrgan_utils import RealESRGANer
use_half = False
if torch.cuda.is_available(): # set False in CPU/MPS mode
# set False for GPUs that don't support f16
no_half_gpu_list = ["1650", "1660"]
if not True in [
gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list
]:
use_half = True
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
upsampler = RealESRGANer(
scale=2,
model_path="./ckpts/realesrgan/RealESRGAN_x2plus.pth",
model=model,
tile=args.bg_tile,
tile_pad=40,
pre_pad=0,
half=use_half,
)
if not gpu_is_available(): # CPU
import warnings
warnings.warn(
"Running on CPU now! Make sure your PyTorch version matches your CUDA."
"The unoptimized RealESRGAN is slow on CPU. "
"If you want to disable it, please remove `--bg_upsampler` and `--face_upsample` in command.",
category=RuntimeWarning,
)
return upsampler
if __name__ == "__main__":
device = get_device()
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input_path",
type=str,
default="None",
help="Input image, video or folder. Default: inputs/whole_imgs",
)
parser.add_argument(
"-o",
"--output_path",
type=str,
default="results/",
help="Output folder. Default: results/",
)
parser.add_argument(
"--save_video", action="store_true", help="Save output as video. Default: False"
)
parser.add_argument(
"-s",
"--upscale",
type=int,
default=2,
help="The final upsampling scale of the image. Default: 1",
)
parser.add_argument(
"--max_length",
type=int,
default=20,
help="Max length of per sub-clip depending of GPU memory. Default: 20",
)
parser.add_argument(
"--has_aligned",
action="store_true",
help="Input are cropped and aligned faces. Default: False",
)
parser.add_argument(
"--only_center_face",
type=bool,
default=True,
help="Only restore the center face. Default: True",
)
parser.add_argument(
"--draw_box",
action="store_true",
help="Draw the bounding box for the detected faces. Default: False",
)
parser.add_argument(
"--detection_model",
type=str,
default="retinaface_resnet50",
help="Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \
Default: retinaface_resnet50",
)
parser.add_argument(
"--bg_upsampler",
type=str,
default="None",
help="Background upsampler. Optional: realesrgan",
)
parser.add_argument(
"--face_upsample",
action="store_true",
help="Face upsampler after enhancement. Default: False",
)
parser.add_argument(
"--bg_tile",
type=int,
default=400,
help="Tile size for background sampler. Default: 400",
)
parser.add_argument(
"--save_video_fps",
type=float,
default=20,
help="Frame rate for saving video. Default: 20",
)
parser.add_argument(
"--ckpt_path", type=str, default="None", help="the loaded ckpt file path"
)
args = parser.parse_args()
input_video = False
ckpt_path = args.ckpt_path
weight_parameter = 1.0
# ------------------ set up background upsampler ------------------
print("------------------ set up background upsampler ------------------")
if args.bg_upsampler == "realesrgan":
bg_upsampler = set_realesrgan()
else:
bg_upsampler = None
# ------------------ set up face upsampler ------------------
if args.face_upsample:
if bg_upsampler is not None:
face_upsampler = bg_upsampler
else:
face_upsampler = set_realesrgan()
else:
face_upsampler = None
# ------------------ set up restorer -------------------
net = ARCH_REGISTRY.get("TemporalCodeFormerDirDistMultiScale")(
dim_embed=512,
n_head=8,
n_layers=9,
codebook_size=1024,
connect_list=["32", "64", "128", "256"],
frame_length=5,
).to(device)
checkpoint = torch.load(ckpt_path)["params_ema"]
net.load_state_dict(checkpoint)
net.eval()
# ------------------ set up FaceRestoreHelper -------------------
# large det_model: 'YOLOv5l', 'retinaface_resnet50'
# small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
if not args.has_aligned:
print(f"Face detection model: {args.detection_model}")
if bg_upsampler is not None:
print(f"Background upsampling: True. Face upsampling: {args.face_upsample}")
else:
print(f"Background upsampling: False. Face upsampling: {args.face_upsample}")
face_helper = FaceRestoreHelper(
args.upscale,
face_size=512,
crop_ratio=(1, 1),
det_model=args.detection_model,
save_ext="png",
use_parse=True,
device=device,
)
# -------------------- start processing ---------------------
input_img_list = []
restored_img_list = []
if args.input_path.endswith(
("mp4", "mov", "avi", "MP4", "MOV", "AVI")
): # input video path
vidreader = VideoReader(args.input_path)
image = vidreader.get_frame()
while image is not None:
input_img_list.append(image)
image = vidreader.get_frame()
fps = (
vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps
)
vidreader.close()
clip_name = os.path.basename(args.input_path)[:-4]
result_root = os.path.join(args.output_path, clip_name)
os.makedirs(result_root, exist_ok=True)
elif os.path.isdir(args.input_path): # input img folder
# scan all the jpg and png images
for img_path in sorted(
glob.glob(os.path.join(args.input_path, "*.[jpJP][pnPN]*[gG]"))
):
input_img_list.append(cv2.imread(img_path))
clip_name = os.path.basename(args.input_path)
result_root = os.path.join(args.output_path, clip_name)
os.makedirs(result_root, exist_ok=True)
else:
raise TypeError(f"Unrecognized type of input video {args.input_path}.")
if len(input_img_list) == 0:
raise FileNotFoundError(
"No input image/video is found...\n"
"\tNote that --input_path for video should end with .mp4|.mov|.avi"
)
if not args.has_aligned:
# Smoothing aligned landmarks
print("Detecting keypoints and smooth alignment ...")
raw_landmarks = []
for i, img in enumerate(input_img_list):
# clean all the intermediate results to process the next image
face_helper.clean_all()
face_helper.read_image(img)
# get face landmarks for each face
num_det_faces = face_helper.get_face_landmarks_5(
only_center_face=args.only_center_face,
resize=640,
eye_dist_threshold=5,
only_keep_largest=True,
)
if num_det_faces == 1:
raw_landmarks.append(face_helper.all_landmarks_5[0].reshape((10,)))
elif num_det_faces == 0:
raw_landmarks.append(np.array([np.nan] * 10))
raw_landmarks = np.array(raw_landmarks)
for i in range(10):
raw_landmarks[:, i] = interpolate_sequence(raw_landmarks[:, i])
video_length = len(input_img_list)
avg_landmarks = gaussian_filter1d(raw_landmarks, 5, axis=0).reshape(
video_length, 5, 2
)
# Pack cropped faces.
cropped_faces = []
for i, img in enumerate(input_img_list):
if not args.has_aligned:
face_helper.clean_all()
face_helper.read_image(img)
face_helper.all_landmarks_5 = [avg_landmarks[i]]
face_helper.align_warp_face()
else:
img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
face_helper.is_gray = is_gray(img, threshold=10)
# if face_helper.is_gray:
# print("Grayscale input: True")
face_helper.cropped_faces = [img]
cropped_face_t = img2tensor(
face_helper.cropped_faces[0] / 255.0, bgr2rgb=True, float32=True
)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_faces.append(cropped_face_t)
cropped_faces = torch.stack(cropped_faces, dim=0).unsqueeze(0).to(device)
print("Restoring faces ...")
with torch.no_grad():
video_length = cropped_faces.shape[1]
output = []
for start_idx in range(0, video_length):
pre_length = args.max_length // 2
post_length = args.max_length - pre_length - 1
padding_begin_idx = start_idx - pre_length
padding_end_idx = start_idx + post_length
if padding_begin_idx < 0:
pre_padding = torch.zeros(
(
cropped_faces.shape[0],
-padding_begin_idx,
*cropped_faces.shape[2:],
),
dtype=cropped_faces.dtype,
device=cropped_faces.device,
)
pre_padding = pre_padding + cropped_faces[:, 0:1]
small_clip = torch.cat(
[pre_padding, cropped_faces[:, : padding_end_idx + 1, ...]], dim=1
)
elif padding_end_idx >= video_length:
post_padding = torch.zeros(
(
cropped_faces.shape[0],
padding_end_idx - video_length + 1,
*cropped_faces.shape[2:],
),
dtype=cropped_faces.dtype,
device=cropped_faces.device,
)
post_padding = post_padding + cropped_faces[:, -1:]
small_clip = torch.cat(
[cropped_faces[:, padding_begin_idx:, ...], post_padding], dim=1
)
else:
small_clip = cropped_faces[
:, padding_begin_idx : padding_end_idx + 1, ...
]
small_clip = rearrange(
small_clip, "b t c h w -> (b t) c h w", t=args.max_length
)
bt = small_clip.shape[0]
res, _, _ = net(small_clip, w=weight_parameter)
res = rearrange(res, "(b t) c h w -> b t c h w", t=args.max_length)
res = res[:, pre_length : pre_length + 1, ...]
output.append(res)
output = torch.cat(output, dim=1).squeeze(0)
assert output.shape[0] == video_length, "Differer number of frames"
restored_faces = [tensor2img(x, rgb2bgr=True, min_max=(-1, 1)) for x in output]
del output
torch.cuda.empty_cache()
print("Saving result ...")
output_path = result_root
os.makedirs(output_path, mode=0o777, exist_ok=True)
if args.save_video:
writer = cv2.VideoWriter(
f"{output_path}.mp4",
fourcc=cv2.VideoWriter_fourcc(*"mp4v"),
fps=args.save_video_fps,
frameSize=(512, 512),
)
for idx, restored_img in enumerate(restored_faces):
img_abs_path = os.path.join(output_path, str(idx).zfill(8) + ".png")
cv2.imwrite(img_abs_path, restored_img, [cv2.IMWRITE_PNG_COMPRESSION, 0])
if args.save_video:
writer.write(restored_img)
if args.save_video:
writer.release()
================================================
FILE: scripts/warp_images.py
================================================
import os
import cv2
import argparse
import glob
import torch
import pdb
import numpy as np
from tqdm import tqdm
from torchvision.transforms.functional import normalize
from basicsr.utils import imwrite, img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils.misc import gpu_is_available, get_device
from scipy.ndimage import gaussian_filter1d
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.utils.misc import is_gray
from basicsr.utils.video_util import VideoReader, VideoWriter
from einops import rearrange
from utils import TDCF_OPT, TCFDD_OPT
from basicsr.utils.registry import ARCH_REGISTRY
def interpolate_sequence(sequence):
interpolated_sequence = np.copy(sequence)
missing_indices = np.isnan(sequence)
if np.any(missing_indices):
valid_indices = ~missing_indices
x = np.arange(len(sequence))
# Interpolate missing values using valid data points
interpolated_sequence[missing_indices] = np.interp(
x[missing_indices], x[valid_indices], sequence[valid_indices])
return interpolated_sequence
def process_single(args, face_helper, input_path, ldmk_folder_path):
input_img_list = []
if input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
vidreader = VideoReader(input_path)
image = vidreader.get_frame()
while image is not None:
input_img_list.append(image)
image = vidreader.get_frame()
fps = vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps
vidreader.close()
clip_name = os.path.basename(input_path)[:-4]
result_root = os.path.join(args.output_path, clip_name)
elif os.path.isdir(args.input_path): # input img folder
# scan all the jpg and png images
for img_path in sorted(glob.glob(os.path.join(input_path, '*.[jpJP][pnPN]*[gG]'))):
input_img_list.append(cv2.imread(img_path))
clip_name = os.path.basename(input_path)
result_root = os.path.join(args.output_path, clip_name)
else:
raise TypeError(f'Unrecognized type of input video {input_path}.')
if len(input_img_list) == 0:
raise FileNotFoundError('No input image/video is found...\n'
'\tNote that --input_path for video should end with .mp4|.mov|.avi')
# Smoothing aligned landmarks
print('Detecting keypoints and smooth alignment ...')
avg_landmarks = []
with open(f"{ldmk_folder_path}/{clip_name}.txt", "r") as f:
for line in f.readlines():
line = line.strip().split()
landmark = np.array([float(_) for _ in line]).reshape(5, 2)
avg_landmarks.append(landmark)
# Save cropped faces.
output_path = os.path.join(args.output_path, f'{clip_name}')
os.makedirs(output_path, mode=0o777, exist_ok=True)
if args.save_video:
writer = cv2.VideoWriter(os.path.join(output_path, f'{clip_name}.mp4'),
fourcc=cv2.VideoWriter_fourcc(*'mp4v'),
fps=args.save_video_fps,
frameSize=(512, 512))
for idx, img in enumerate(input_img_list):
face_helper.clean_all()
face_helper.read_image(img)
face_helper.all_landmarks_5 = [avg_landmarks[idx]]
face_helper.align_warp_face()
img_abs_path = os.path.join(output_path, str(idx).zfill(8)+'.png')
cv2.imwrite(img_abs_path, face_helper.cropped_faces[0], [cv2.IMWRITE_PNG_COMPRESSION, 0])
if args.save_video:
writer.write(face_helper.cropped_faces[0])
if args.save_video:
writer.release()
print(f'All results are saved in {result_root}')
if __name__ == '__main__':
device = get_device()
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input_path', type=str, default='../dataset/real_LQ/',
help='no warped images')
parser.add_argument('-o', '--output_path', type=str, default='results/',
help='Output folder. Default: results/')
parser.add_argument('-l', '--ldmk_folder_path', type=str, required=True,
help='landmarks info folder.')
parser.add_argument('--save_video', action='store_true',
help='Save output as video. Default: False')
parser.add_argument('-s', '--upscale', type=int, default=1,
help='The final upsampling scale of the image. Default: 1')
parser.add_argument('--detection_model', type=str, default='retinaface_resnet50',
help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \
Default: retinaface_resnet50')
parser.add_argument('--bg_tile', type=int, default=0,
help='Tile size for background sampler. Default: 400')
parser.add_argument('--save_video_fps', type=float, default=24,
help='Frame rate for saving video. Default: 20')
args = parser.parse_args()
# ------------------ set up FaceRestoreHelper -------------------
face_helper = FaceRestoreHelper(
args.upscale,
face_size=512,
crop_ratio=(1, 1),
det_model=args.detection_model,
save_ext='png',
use_parse=False,
device=device)
for _, clip_name in enumerate(tqdm(os.listdir(args.input_path))):
process_single(args,
face_helper,
os.path.join(args.input_path, clip_name), args.ldmk_folder_path)
================================================
FILE: train.sh
================================================
CUDA_VISIBLE_DEVICES=0 torchrun \
--nproc_per_node=1 --master_port=29597 \
basicsr/train.py \
-opt options/clip5_bs2_512_align_nofix_multiscale.yaml \
--launcher pytorch