Repository: greatlog/UnpairedSR Branch: master Commit: 771312ee2bd4 Files: 308 Total size: 1.2 MB Directory structure: gitextract_gen0q67r/ ├── .gitignore ├── README.md └── codes/ ├── config/ │ ├── BSRGAN/ │ │ ├── README.md │ │ ├── archs/ │ │ │ ├── __init__.py │ │ │ ├── discriminator.py │ │ │ ├── edsr.py │ │ │ ├── loss.py │ │ │ ├── lr_scheduler.py │ │ │ ├── module_util.py │ │ │ ├── rcan.py │ │ │ ├── rrdb.py │ │ │ ├── srresnet.py │ │ │ ├── translator.py │ │ │ └── vgg.py │ │ ├── count_flops.py │ │ ├── inference.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ └── sr_model.py │ │ ├── options/ │ │ │ └── test/ │ │ │ ├── 2017Track2_2020Track1.yml │ │ │ ├── 2018Track2_2018Track4.yml │ │ │ └── 2020Track2.yml │ │ ├── test.py │ │ └── train.py │ ├── Bicubic/ │ │ ├── README.md │ │ ├── archs/ │ │ │ ├── __init__.py │ │ │ ├── bicubic.py │ │ │ ├── discriminator.py │ │ │ ├── edsr.py │ │ │ ├── loss.py │ │ │ ├── lr_scheduler.py │ │ │ ├── module_util.py │ │ │ ├── rcan.py │ │ │ ├── rrdb.py │ │ │ ├── srresnet.py │ │ │ └── vgg.py │ │ ├── count_flops.py │ │ ├── inference.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ └── sr_model.py │ │ ├── options/ │ │ │ └── test/ │ │ │ ├── 2017Track2_2020Track1.yml │ │ │ ├── 2018Track2_2020Track4.yml │ │ │ └── 2020Track2.yml │ │ ├── test.py │ │ └── train.py │ ├── Bulat/ │ │ ├── README.md │ │ ├── archs/ │ │ │ ├── __init__.py │ │ │ ├── deg_arch.py │ │ │ ├── discriminator.py │ │ │ ├── edsr.py │ │ │ ├── loss.py │ │ │ ├── lr_scheduler.py │ │ │ ├── module_util.py │ │ │ ├── rcan.py │ │ │ ├── rrdb.py │ │ │ ├── srresnet.py │ │ │ ├── translator.py │ │ │ └── vgg.py │ │ ├── count_flops.py │ │ ├── inference.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ └── deg_sr_model.py │ │ ├── options/ │ │ │ ├── test/ │ │ │ │ ├── 2017Track2.yml │ │ │ │ ├── 2018Track2.yml │ │ │ │ ├── 2018Track4.yml │ │ │ │ └── 2020Track1.yml │ │ │ └── train/ │ │ │ └── psnr/ │ │ │ ├── 2017Track2.yml │ │ │ ├── 2018Track2.yml │ │ 
│ ├── 2018Track4.yml │ │ │ └── 2020Track1.yml │ │ ├── test.py │ │ └── train.py │ ├── CinGAN/ │ │ ├── README.md │ │ ├── archs/ │ │ │ ├── __init__.py │ │ │ ├── discriminator.py │ │ │ ├── edsr.py │ │ │ ├── loss.py │ │ │ ├── lr_scheduler.py │ │ │ ├── module_util.py │ │ │ ├── rcan.py │ │ │ ├── rrdb.py │ │ │ ├── srresnet.py │ │ │ ├── translator.py │ │ │ └── vgg.py │ │ ├── count_flops.py │ │ ├── inference.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ ├── cingan_model.py │ │ │ └── trans_model.py │ │ ├── options/ │ │ │ ├── test/ │ │ │ │ └── sr/ │ │ │ │ ├── 2017Track1.yml │ │ │ │ ├── 2018Track2.yml │ │ │ │ ├── 2018Track4.yml │ │ │ │ └── 2020Track1.yml │ │ │ └── train/ │ │ │ ├── sr/ │ │ │ │ ├── 2017Track2.yml │ │ │ │ ├── 2018Track2.yml │ │ │ │ ├── 2018Track4.yml │ │ │ │ └── 2020Track1.yml │ │ │ └── trans/ │ │ │ ├── 2017Track2.yml │ │ │ ├── 2018Track2.yml │ │ │ ├── 2018Track4.yml │ │ │ └── 2020Track1.yml │ │ ├── test.py │ │ └── train.py │ ├── CycleSR/ │ │ ├── README.md │ │ ├── archs/ │ │ │ ├── __init__.py │ │ │ ├── discriminator.py │ │ │ ├── edsr.py │ │ │ ├── loss.py │ │ │ ├── lr_scheduler.py │ │ │ ├── module_util.py │ │ │ ├── rcan.py │ │ │ ├── rrdb.py │ │ │ ├── srresnet.py │ │ │ ├── translator.py │ │ │ └── vgg.py │ │ ├── count_flops.py │ │ ├── inference.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ ├── cyclegan_model.py │ │ │ └── cyclesr_model.py │ │ ├── options/ │ │ │ ├── test/ │ │ │ │ └── sr/ │ │ │ │ ├── 2017Track1.yml │ │ │ │ ├── 2018Track2.yml │ │ │ │ ├── 2018Track4.yml │ │ │ │ ├── 2020Track1.yml │ │ │ │ └── 2020Track1_percep.yml │ │ │ └── train/ │ │ │ ├── sr/ │ │ │ │ └── psnr/ │ │ │ │ ├── 2017Track2.yml │ │ │ │ ├── 2018Track2.yml │ │ │ │ ├── 2018Track4.yml │ │ │ │ └── 2020Track1.yml │ │ │ └── trans/ │ │ │ ├── 2017Track2.yml │ │ │ ├── 2018Track2.yml │ │ │ ├── 2018Track4.yml │ │ │ └── 2020Track1.yml │ │ ├── test.py │ │ └── train.py │ ├── DSGANSR/ │ │ ├── README.md │ │ ├── archs/ │ │ │ ├── __init__.py │ │ │ ├── deg_arch.py │ 
│ │ ├── discriminator.py │ │ │ ├── edsr.py │ │ │ ├── loss.py │ │ │ ├── lr_scheduler.py │ │ │ ├── module_util.py │ │ │ ├── rcan.py │ │ │ ├── rrdb.py │ │ │ ├── srresnet.py │ │ │ ├── translator.py │ │ │ └── vgg.py │ │ ├── count_flops.py │ │ ├── inference.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ └── deg_sr_model.py │ │ ├── options/ │ │ │ ├── test/ │ │ │ │ ├── 2017Track1.yml │ │ │ │ ├── 2018Track2.yml │ │ │ │ ├── 2018Track4.yml │ │ │ │ └── 2020Track1.yml │ │ │ └── train/ │ │ │ ├── deg/ │ │ │ │ ├── 2017Track2.yml │ │ │ │ ├── 2018Track2.yml │ │ │ │ ├── 2018Track4.yml │ │ │ │ └── 2020Track1.yml │ │ │ └── sr/ │ │ │ ├── 2017Track2.yml │ │ │ ├── 2018Track2.yml │ │ │ ├── 2018Track4.yml │ │ │ └── 2020Track1.yml │ │ ├── test.py │ │ └── train.py │ ├── EDSR/ │ │ ├── archs/ │ │ │ ├── __init__.py │ │ │ ├── bicubic.py │ │ │ ├── discriminator.py │ │ │ ├── edsr.py │ │ │ ├── loss.py │ │ │ ├── lr_scheduler.py │ │ │ ├── module_util.py │ │ │ ├── rcan.py │ │ │ ├── rrdb.py │ │ │ ├── srresnet.py │ │ │ ├── translator.py │ │ │ └── vgg.py │ │ ├── count_flops.py │ │ ├── inference.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ └── sr_model.py │ │ ├── options/ │ │ │ └── test/ │ │ │ ├── 2017Track2_2020Track1.yml │ │ │ ├── 2018Track2_2020Track4.yml │ │ │ └── 2020Track2.yml │ │ ├── test.py │ │ └── train.py │ ├── Maeda/ │ │ ├── README.md │ │ ├── archs/ │ │ │ ├── __init__.py │ │ │ ├── discriminator.py │ │ │ ├── edsr.py │ │ │ ├── loss.py │ │ │ ├── lr_scheduler.py │ │ │ ├── module_util.py │ │ │ ├── rcan.py │ │ │ ├── rrdb.py │ │ │ ├── srresnet.py │ │ │ ├── translator.py │ │ │ └── vgg.py │ │ ├── count_flops.py │ │ ├── inference.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ └── pseudo_supervision_model.py │ │ ├── options/ │ │ │ ├── test/ │ │ │ │ ├── 2017Track2.yml │ │ │ │ ├── 2018Track2.yml │ │ │ │ ├── 2018Track4.yml │ │ │ │ └── 2020Track1.yml │ │ │ └── train/ │ │ │ ├── 2017Track2.yml │ │ │ ├── 2018Track2.yml │ │ │ ├── 
2018Track4.yml │ │ │ └── 2020Track1.yml │ │ ├── test.py │ │ └── train.py │ ├── PDM-SR/ │ │ ├── archs/ │ │ │ ├── __init__.py │ │ │ ├── deg_arch.py │ │ │ ├── discriminator.py │ │ │ ├── edsr.py │ │ │ ├── loss.py │ │ │ ├── lr_scheduler.py │ │ │ ├── module_util.py │ │ │ ├── rcan.py │ │ │ ├── rrdb.py │ │ │ ├── srresnet.py │ │ │ └── vgg.py │ │ ├── count_flops.py │ │ ├── inference.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ └── deg_sr_model.py │ │ ├── options/ │ │ │ ├── test/ │ │ │ │ ├── 2017Track1.yml │ │ │ │ ├── 2018Track2.yml │ │ │ │ ├── 2018Track4.yml │ │ │ │ ├── 2020Track1.yml │ │ │ │ └── 2020Track2.yml │ │ │ └── train/ │ │ │ ├── deg/ │ │ │ │ ├── 2017Track1.yml │ │ │ │ ├── 2018Track2.yml │ │ │ │ ├── 2018Track4.yml │ │ │ │ ├── 2020Track1.yml │ │ │ │ └── 2020Track2.yml │ │ │ ├── percep/ │ │ │ │ ├── 2017Track1.yml │ │ │ │ ├── 2018Track2.yml │ │ │ │ ├── 2018Track4.yml │ │ │ │ ├── 2020Track1.yml │ │ │ │ └── 2020Track2.yml │ │ │ └── psnr/ │ │ │ ├── 2017Track2.yml │ │ │ ├── 2018Track2.yml │ │ │ ├── 2018Track4.yml │ │ │ ├── 2020Track1.yml │ │ │ └── 2020Track2.yml │ │ ├── test.py │ │ └── train.py │ └── RealESRGAN/ │ ├── README.md │ ├── archs/ │ │ ├── __init__.py │ │ ├── discriminator.py │ │ ├── edsr.py │ │ ├── loss.py │ │ ├── lr_scheduler.py │ │ ├── module_util.py │ │ ├── rcan.py │ │ ├── rrdb.py │ │ ├── srresnet.py │ │ ├── translator.py │ │ └── vgg.py │ ├── count_flops.py │ ├── inference.py │ ├── models/ │ │ ├── __init__.py │ │ ├── base_model.py │ │ └── sr_model.py │ ├── options/ │ │ └── test/ │ │ ├── 2017Track2_2020Track1.yml │ │ ├── 2018Track2_2018Track4.yml │ │ └── 2020Track2.yml │ ├── test.py │ └── train.py ├── data/ │ ├── __init__.py │ ├── data_sampler.py │ ├── debug_dataset.py │ ├── fixed_image_dataset.py │ ├── paired_ref_dataset.py │ ├── paried_dataset.py │ ├── single_dataset.py │ ├── single_image_dataset.py │ └── unpaired_dataset.py ├── metrics/ │ ├── __init__.py │ ├── best_psnr.py │ ├── measure.py │ ├── psnr.py │ └── ssim.py ├── scripts/ │ 
├── create_lmdb.py │ ├── extract_subimgs_single.py │ ├── generate_mod_LR_bic.m │ ├── generate_mod_LR_bic.py │ ├── generate_mod_blur_LR_bic.py │ └── test_imgs.py └── utils/ ├── __init__.py ├── data_utils.py ├── deg_utils.py ├── file_utils.py ├── img_utils.py ├── option.py ├── registry.py └── resize_utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ __pycache__/ experiments/ results/ result/ result log/ log data_samples/ checkpoints/ *.pkl *.pt *.pth *.jpg *.png *.state *.event ================================================ FILE: README.md ================================================ This is an offical implementation of the CVPR2022's paper [Learning the Degradation Distribution for Blind Image Super-Resolution](https://arxiv.org/abs/2203.04962). This repo also contains the implementations of many other blind SR methods in [config](codes/config/), including CinGAN, CycleSR, DSGAN-SR, etc. If you find this repo useful for your work, please cite our paper: ``` @inproceedings{PDMSR, title={Learning the Degradation Distribution for Blind Image Super-Resolution}, author={Zhengxiong Luo and Yan Huang and and Shang Li and Liang Wang and Tieniu Tan}, booktitle={CVPR}, year={2022} } ``` The codes are built on the basis of [BasicSR](https://github.com/xinntao/BasicSR). ## Dependences 1. lpips (pip install --user lpips) 2. matlab (to support the evaluation of NIQE). The details about installing a matlab API for python can refer to [here](https://ww2.mathworks.cn/help/matlab/matlab_external/install-the-matlab-engine-for-python.html) ## Datasets The datasets in NTIRE2017 and NTIRE2018 can be downloaded from [here](https://data.vision.ee.ethz.ch/cvl/DIV2K/). The datasets in NTIRE2020 can be downloaded from the [competition site](https://competitions.codalab.org/competitions/22220). 
## Start up We provide the checkpoints in in [Google drive](https://drive.google.com/drive/folders/1bVMGaGF7yLyQhM0xmRVMD2SolOtgLvxO?usp=sharing) and [BaiduYun](https://pan.baidu.com/s/1BcYcX0yCS-3-6XqT4BgYAQ?pwd=ovmw)(password: ovmw). Please download them into the [checkpoints](checkpoints/) directoty. To get a quick start: ```bash cd codes/config/PDM-SR/ python3 inference.py --opt options/test/2020Track2.yml ``` ================================================ FILE: codes/config/BSRGAN/README.md ================================================ This repo currently only supports the test of [BSRGAN](https://arxiv.org/abs/2103.14006). The training related codes may be added in the future. ================================================ FILE: codes/config/BSRGAN/archs/__init__.py ================================================ import importlib import os import os.path as osp from utils.registry import ARCH_REGISTRY, LOSS_REGISTRY, LR_SCHEDULER_REGISTRY arch_folder = osp.dirname(osp.abspath(__file__)) arch_filenames = [ osp.splitext(osp.basename(v))[0] for v in os.listdir(arch_folder) if v.endswith(".py") ] # import all the arch modules _arch_modules = [ importlib.import_module(f"archs.{file_name}") for file_name in arch_filenames ] def build_network(net_opt): which_network = net_opt["which_network"] net = ARCH_REGISTRY.get(which_network)(**net_opt["setting"]) return net def build_loss(loss_opt): loss_type = loss_opt.pop("type") loss = LOSS_REGISTRY.get(loss_type)(**loss_opt) return loss def build_scheduler(optimizer, scheduler_opt): scheduler_type = scheduler_opt.pop("type") scheduler = LR_SCHEDULER_REGISTRY.get(scheduler_type)(optimizer, **scheduler_opt) return scheduler ================================================ FILE: codes/config/BSRGAN/archs/discriminator.py ================================================ import torch import torch.nn as nn import torchvision import functools from utils.registry import ARCH_REGISTRY @ARCH_REGISTRY.register() class 
DiscriminatorVGG128(nn.Module): def __init__(self, in_nc, nf): super().__init__() # [64, 128, 128] self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) self.bn0_1 = nn.BatchNorm2d(nf, affine=True) # [64, 64, 64] self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) # [128, 32, 32] self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) # [256, 16, 16] self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) # [512, 8, 8] self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) self.linear1 = nn.Linear(512 * 4 * 4, 100) self.linear2 = nn.Linear(100, 1) # activation function self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): fea = self.lrelu(self.conv0_0(x)) fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) fea = fea.view(fea.size(0), -1) fea = self.lrelu(self.linear1(fea)) out = self.linear2(fea) return out @ARCH_REGISTRY.register() class 
DiscriminatorVGG32(nn.Module): def __init__(self, in_nc, nf): super().__init__() # [64, 128, 128] self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) self.bn0_1 = nn.BatchNorm2d(nf, affine=True) # [64, 64, 64] self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) # [128, 32, 32] self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) # [256, 16, 16] self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) # [512, 8, 8] self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) self.linear1 = nn.Linear(512, 100) self.linear2 = nn.Linear(100, 1) # activation function self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): fea = self.lrelu(self.conv0_0(x)) fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) fea = fea.view(fea.size(0), -1) fea = self.lrelu(self.linear1(fea)) out = self.linear2(fea) return out @ARCH_REGISTRY.register() class 
PatchGANDiscriminator(nn.Module): """Defines a PatchGAN discriminator""" def __init__(self, in_c, nf, nb, stride=1, norm_layer=nn.InstanceNorm2d): """Construct a PatchGAN discriminator Parameters: input_nc (int) -- the number of channels in input images ndf (int) -- the number of filters in the last conv layer n_layers (int) -- the number of conv layers in the discriminator norm_layer -- normalization layer """ super().__init__() if ( type(norm_layer) == functools.partial ): # no need to use bias as BatchNorm2d has affine parameters use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d kw = 3 padw = 1 sequence = [ nn.Conv2d(in_c, nf, kernel_size=kw, stride=1, padding=padw), nn.LeakyReLU(0.2, True), ] nf_mult = 1 nf_mult_prev = 1 for n in range(1, nb): # gradually increase the number of filters nf_mult_prev = nf_mult nf_mult = min(2 ** n, 8) sequence += [ nn.Conv2d( nf * nf_mult_prev, nf * nf_mult, kernel_size=kw, stride=stride, padding=padw, bias=use_bias, ), norm_layer(nf * nf_mult), nn.LeakyReLU(0.2, True), ] nf_mult_prev = nf_mult nf_mult = min(2 ** nb, 8) sequence += [ nn.Conv2d( nf * nf_mult_prev, nf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias, ), norm_layer(nf * nf_mult), nn.LeakyReLU(0.2, True), ] sequence += [ nn.Conv2d(nf * nf_mult, nf, kernel_size=kw, stride=1, padding=padw) ] # output 1 channel prediction map self.model = nn.Sequential(*sequence) def forward(self, input): """Standard forward.""" return self.model(input) ================================================ FILE: codes/config/BSRGAN/archs/edsr.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from utils.registry import ARCH_REGISTRY def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias ) class 
MeanShift(nn.Conv2d): def __init__( self, rgb_range, rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1, ): super(MeanShift, self).__init__(3, 3, kernel_size=1) std = torch.Tensor(rgb_std) self.weight.data = torch.eye(3).view(3, 3, 1, 1) self.weight.data.div_(std.view(3, 1, 1, 1)) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) self.bias.data.div_(std) self.requires_grad = False class BasicBlock(nn.Sequential): def __init__( self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True), ): m = [ nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), stride=stride, bias=bias, ) ] if bn: m.append(nn.BatchNorm2d(out_channels)) if act is not None: m.append(act) super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): def __init__( self, conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(ResBlock, self).__init__() m = [] for i in range(2): m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: m.append(nn.BatchNorm2d(n_feat)) if i == 0: m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale def forward(self, x): res = self.body(x).mul(self.res_scale) res += x return res class Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
for _ in range(int(math.log(scale, 2))): m.append(conv(n_feat, 4 * n_feat, 3, bias)) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 3: m.append(conv(n_feat, 9 * n_feat, 3, bias)) m.append(nn.PixelShuffle(3)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 1: m.append(nn.Identity()) else: raise NotImplementedError super(Upsampler, self).__init__(*m) def make_model(args, parent=False): return RCAN(args) ## Channel Attention (CA) Layer @ARCH_REGISTRY.register() class EDSR(nn.Module): def __init__(self, nb, nf, res_scale=0.1, upscale=4, conv=default_conv): super(EDSR, self).__init__() n_resblocks = nb n_feats = nf kernel_size = 3 scale = upscale act = nn.ReLU(True) # url_name = 'r{}f{}x{}'.format(nb, nf, upscale) # if url_name in url: # self.url = url[url_name] # else: # self.url = None self.sub_mean = MeanShift(255.0, sign=-1) self.add_mean = MeanShift(255.0, sign=1) # define head module m_head = [conv(3, n_feats, kernel_size)] # define body module m_body = [ ResBlock(conv, n_feats, kernel_size, act=act, res_scale=res_scale) for _ in range(n_resblocks) ] m_body.append(conv(n_feats, n_feats, kernel_size)) # define tail module m_tail = [ Upsampler(conv, scale, n_feats, act=False), conv(n_feats, 3, kernel_size), ] self.head = nn.Sequential(*m_head) self.body = nn.Sequential(*m_body) self.tail = nn.Sequential(*m_tail) def forward(self, x): x = self.sub_mean(x * 255.0) x = self.head(x) res = self.body(x) res += x x = self.tail(res) x = self.add_mean(x) / 255.0 return x ================================================ FILE: codes/config/BSRGAN/archs/loss.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import lpips as lp from utils.registry import LOSS_REGISTRY from .vgg import VGGFeatureExtractor @LOSS_REGISTRY.register() class GaussGuided(nn.Module): def __init__(self, ksize, sigma): super().__init__() ax = 
torch.arange(0, ksize) - ksize//2 xx, yy = torch.meshgrid(ax, ax) dis = (xx ** 2 + yy ** 2) dis = torch.exp(-dis / sigma ** 2) dis = dis / dis.sum() self.register_buffer("gauss", dis.view(1, ksize**2, 1, 1)) def forward(self, kernel): return F.mse_loss(self.gauss, kernel) @LOSS_REGISTRY.register() class PerceptualLossLPIPS(nn.Module): def __init__(self, net="alex", normalize=True): super().__init__() self.fn = lp.LPIPS(net=net, spatial=True) for p in self.fn.parameters(): p.requires_grad = False self.normalize = normalize def forward(self, res, ref): return self.fn(res, ref, normalize=self.normalize).mean(), None @LOSS_REGISTRY.register() class MSELoss(nn.Module): def __init__(self, *args, **kwargs): super().__init__() def forward(self, res, ref): return F.mse_loss(res, ref) @LOSS_REGISTRY.register() class L1Loss(nn.Module): def __init__(self, *args, **kwargs): super().__init__() def forward(self, res, ref): return F.l1_loss(res, ref) @LOSS_REGISTRY.register() class GANLoss(nn.Module): """Define GAN loss. Args: gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. real_label_val (float): The value for real label. Default: 1.0. fake_label_val (float): The value for fake label. Default: 0.0. """ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): super(GANLoss, self).__init__() self.gan_type = gan_type self.real_label_val = real_label_val self.fake_label_val = fake_label_val if self.gan_type == "vanilla": self.loss = nn.BCEWithLogitsLoss() elif self.gan_type == "lsgan": self.loss = nn.MSELoss() elif self.gan_type == "wgan": self.loss = self._wgan_loss elif self.gan_type == "wgan_softplus": self.loss = self._wgan_softplus_loss elif self.gan_type == "hinge": self.loss = nn.ReLU() else: raise NotImplementedError(f"GAN type {self.gan_type} is not implemented.") def _wgan_loss(self, input, target): """wgan loss. Args: input (Tensor): Input tensor. target (bool): Target label. Returns: Tensor: wgan loss. 
""" return -input.mean() if target else input.mean() def _wgan_softplus_loss(self, input, target): """wgan loss with soft plus. softplus is a smooth approximation to the ReLU function. In StyleGAN2, it is called: Logistic loss for discriminator; Non-saturating loss for generator. Args: input (Tensor): Input tensor. target (bool): Target label. Returns: Tensor: wgan loss. """ return F.softplus(-input).mean() if target else F.softplus(input).mean() def get_target_label(self, input, target_is_real): """Get target label. Args: input (Tensor): Input tensor. target_is_real (bool): Whether the target is real or fake. Returns: (bool | Tensor): Target tensor. Return bool for wgan, otherwise, return Tensor. """ if self.gan_type in ["wgan", "wgan_softplus"]: return target_is_real target_val = self.real_label_val if target_is_real else self.fake_label_val return input.new_ones(input.size()) * target_val def forward(self, input, target_is_real, is_disc=False): """ Args: input (Tensor): The input for the loss module, i.e., the network prediction. target_is_real (bool): Whether the targe is real or fake. is_disc (bool): Whether the loss for discriminators or not. Default: False. Returns: Tensor: GAN loss value. """ target_label = self.get_target_label(input, target_is_real) if self.gan_type == "hinge": if is_disc: # for discriminators in hinge-gan input = -input if target_is_real else input loss = self.loss(1 + input).mean() else: # for generators in hinge-gan loss = -input.mean() else: # other gan types loss = self.loss(input, target_label) return loss @LOSS_REGISTRY.register() class PerceptualLoss(nn.Module): """Perceptual loss with commonly used style loss. Args: layer_weights (dict): The weight for each layer of vgg feature. Here is an example: {'conv5_4': 1.}, which means the conv5_4 feature layer (before relu5_4) will be extracted with weight 1.0 in calculting losses. vgg_type (str): The type of vgg network used as feature extractor. Default: 'vgg19'. 
use_input_norm (bool): If True, normalize the input image in vgg. Default: True. range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. Default: False. perceptual_weight (float): If `perceptual_weight > 0`, the perceptual loss will be calculated and the loss will multiplied by the weight. Default: 1.0. style_weight (float): If `style_weight > 0`, the style loss will be calculated and the loss will multiplied by the weight. Default: 0. criterion (str): Criterion used for perceptual loss. Default: 'l1'. """ def __init__( self, layer_weights, vgg_type="vgg19", use_input_norm=True, range_norm=False, perceptual_weight=1.0, style_weight=0.0, criterion="l1", ): super(PerceptualLoss, self).__init__() self.perceptual_weight = perceptual_weight self.style_weight = style_weight self.layer_weights = layer_weights self.vgg = VGGFeatureExtractor( layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, use_input_norm=use_input_norm, range_norm=range_norm, ) self.criterion_type = criterion if self.criterion_type == "l1": self.criterion = torch.nn.L1Loss() elif self.criterion_type == "l2": self.criterion = torch.nn.L2loss() elif self.criterion_type == "fro": self.criterion = None else: raise NotImplementedError(f"{criterion} criterion has not been supported.") def forward(self, x, gt): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). gt (Tensor): Ground-truth tensor with shape (n, c, h, w). Returns: Tensor: Forward results. 
""" # extract vgg features x_features = self.vgg(x) gt_features = self.vgg(gt.detach()) # calculate perceptual loss if self.perceptual_weight > 0: percep_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": percep_loss += ( torch.norm(x_features[k] - gt_features[k], p="fro") * self.layer_weights[k] ) else: percep_loss += ( self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] ) percep_loss *= self.perceptual_weight else: percep_loss = None # calculate style loss if self.style_weight > 0: style_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": style_loss += ( torch.norm( self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p="fro", ) * self.layer_weights[k] ) else: style_loss += ( self.criterion( self._gram_mat(x_features[k]), self._gram_mat(gt_features[k]), ) * self.layer_weights[k] ) style_loss *= self.style_weight else: style_loss = None return percep_loss, style_loss def _gram_mat(self, x): """Calculate Gram matrix. Args: x (torch.Tensor): Tensor with shape of (n, c, h, w). Returns: torch.Tensor: Gram matrix. 
""" n, c, h, w = x.size() features = x.view(n, c, w * h) features_t = features.transpose(1, 2) gram = features.bmm(features_t) / (c * h * w) return gram @LOSS_REGISTRY.register() class CharbonnierLoss(nn.Module): """Charbonnier Loss (L1)""" def __init__(self, eps=1e-6): super(CharbonnierLoss, self).__init__() self.eps = eps def forward(self, x, y): diff = x - y loss = torch.mean(torch.sqrt(diff * diff + self.eps)) return loss class GradientPenaltyLoss(nn.Module): def __init__(self, device=torch.device("cpu")): super(GradientPenaltyLoss, self).__init__() self.register_buffer("grad_outputs", torch.Tensor()) self.grad_outputs = self.grad_outputs.to(device) def get_grad_outputs(self, input): if self.grad_outputs.size() != input.size(): self.grad_outputs.resize_(input.size()).fill_(1.0) return self.grad_outputs def forward(self, interp, interp_crit): grad_outputs = self.get_grad_outputs(interp_crit) grad_interp = torch.autograd.grad( outputs=interp_crit, inputs=interp, grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True, )[0] grad_interp = grad_interp.view(grad_interp.size(0), -1) grad_interp_norm = grad_interp.norm(2, dim=1) loss = ((grad_interp_norm - 1) ** 2).mean() return loss ================================================ FILE: codes/config/BSRGAN/archs/lr_scheduler.py ================================================ import math from collections import Counter, defaultdict import torch from torch.optim.lr_scheduler import _LRScheduler from utils.registry import LR_SCHEDULER_REGISTRY @LR_SCHEDULER_REGISTRY.register() class LinearDecayLR(_LRScheduler): def __init__( self, optimizer, decay_prop, total_steps, last_epoch=-1, ): self.decay_prop = decay_prop self.total_steps = total_steps super().__init__(optimizer, last_epoch) def get_lr(self): return [ group["initial_lr"] * (1 - (self.last_epoch + 1) * self.decay_prop / self.total_steps) for group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class 
MultiStepRestartLR(_LRScheduler): def __init__( self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, clear_state=False, last_epoch=-1, ): self.milestones = Counter(milestones) self.gamma = gamma self.clear_state = clear_state self.restarts = restarts if restarts else [0] self.restart_weights = weights if weights else [1] assert len(self.restarts) == len( self.restart_weights ), "restarts and their weights do not match." super().__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch in self.restarts: if self.clear_state: self.optimizer.state = defaultdict(dict) weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [ group["initial_lr"] * weight for group in self.optimizer.param_groups ] if self.last_epoch not in self.milestones: return [group["lr"] for group in self.optimizer.param_groups] return [ group["lr"] * self.gamma ** self.milestones[self.last_epoch] for group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class CosineAnnealingRestartLR(_LRScheduler): def __init__( self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1 ): self.T_period = T_period self.T_max = self.T_period[0] # current T period self.eta_min = eta_min self.restarts = restarts if restarts else [0] self.restart_weights = weights if weights else [1] self.last_restart = 0 assert len(self.restarts) == len( self.restart_weights ), "restarts and their weights do not match." 
super().__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch == 0: return self.base_lrs elif self.last_epoch in self.restarts: self.last_restart = self.last_epoch self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [ group["initial_lr"] * weight for group in self.optimizer.param_groups ] elif (self.last_epoch - self.last_restart - 1 - self.T_max) % ( 2 * self.T_max ) == 0: return [ group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) ] return [ (1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / ( 1 + math.cos( math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max ) ) * (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups ] ================================================ FILE: codes/config/BSRGAN/archs/module_util.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init def initialize_weights(net_l, scale=1): if not isinstance(net_l, list): net_l = [net_l] for net in net_l: for m in net.modules(): if isinstance(m, nn.Conv2d): init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale # for residual block if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.Linear): init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): init.constant_(m.weight, 1) init.constant_(m.bias.data, 0.0) def make_layer(block, n_layers): layers = [] for _ in range(n_layers): layers.append(block()) return nn.Sequential(*layers) class ResidualBlock_noBN(nn.Module): """Residual block w/o BN ---Conv-ReLU-Conv-+- |________________| """ def __init__(self, nf=64): super(ResidualBlock_noBN, self).__init__() 
def flow_warp(x, flow, interp_mode="bilinear", padding_mode="zeros"):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): size (N, C, H, W)
        flow (Tensor): size (N, H, W, 2), flow offsets in *pixel* units
            (last dim ordered (dx, dy) to match the grid built below).
        interp_mode (str): 'nearest' or 'bilinear'
        padding_mode (str): 'zeros' or 'border' or 'reflection'

    Returns:
        Tensor: warped image or feature map, same shape as ``x``.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    B, C, H, W = x.size()
    # mesh grid of absolute pixel coordinates; meshgrid without an indexing
    # argument uses 'ij' ordering, so grid_y varies along H and grid_x along W.
    # NOTE(review): newer torch versions warn unless indexing= is passed —
    # left unchanged to preserve behavior.
    grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False
    grid = grid.type_as(x)
    # (H, W, 2) grid broadcasts against the (N, H, W, 2) flow.
    vgrid = grid + flow
    # scale grid to [-1,1] as required by F.grid_sample (align_corners
    # convention: -1/+1 map to the corner pixel centers).
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
    return output
class ResBlock(nn.Module):
    """Plain residual block: conv(-BN)-act-conv(-BN) plus a scaled skip.

    ``conv`` is a factory ``conv(in_ch, out_ch, kernel_size, bias=...)``;
    the activation object is inserted only after the first conv.
    """

    def __init__(
        self,
        conv,
        n_feat,
        kernel_size,
        bias=True,
        bn=False,
        act=nn.ReLU(True),
        res_scale=1,
    ):
        super(ResBlock, self).__init__()
        layers = []
        for idx in range(2):
            layers.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                layers.append(nn.BatchNorm2d(n_feat))
            # Activation only between the two convs, never after the second.
            if idx == 0:
                layers.append(act)
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        # Residual branch is scaled before being added to the identity path.
        return x + self.body(x).mul(self.res_scale)
class CALayer(nn.Module):
    """Channel attention: squeeze via global average pooling, excite via a
    bottleneck 1x1-conv MLP with a sigmoid gate, then rescale the input."""

    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # Squeeze: collapse each channel's feature map to a single scalar.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excite: channel -> channel//reduction -> channel, gated by sigmoid.
        # (Attribute name `conv_du` kept so checkpoint keys stay compatible.)
        squeeze = nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True)
        excite = nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True)
        self.conv_du = nn.Sequential(
            squeeze,
            nn.ReLU(inplace=True),
            excite,
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Per-channel gate in (0, 1), broadcast over the spatial dims.
        gate = self.conv_du(self.avg_pool(x))
        return x * gate
modules_body.append(conv(n_feat, n_feat, kernel_size)) self.body = nn.Sequential(*modules_body) def forward(self, x): res = self.body(x) res += x return res ## Residual Channel Attention Network (RCAN) @ARCH_REGISTRY.register() class RCAN(nn.Module): def __init__(self, ng, nb, nf, reduction=16, upscale=4, conv=default_conv): super(RCAN, self).__init__() n_resgroups = ng n_resblocks = nb n_feats = nf kernel_size = 3 reduction = reduction scale = upscale act = nn.ReLU(True) # RGB mean for DIV2K rgb_mean = (0.4488, 0.4371, 0.4040) rgb_std = (1.0, 1.0, 1.0) self.sub_mean = MeanShift(1.0, rgb_mean, rgb_std, -1) # define head module modules_head = [conv(3, n_feats, kernel_size)] # define body module modules_body = [ ResidualGroup( conv, n_feats, kernel_size, reduction, act=act, res_scale=1.0, n_resblocks=nb, ) for _ in range(ng) ] modules_body.append(conv(n_feats, n_feats, kernel_size)) # define tail module modules_tail = [ Upsampler(conv, scale, n_feats, act=False), conv(n_feats, 3, kernel_size), ] self.add_mean = MeanShift(1.0, rgb_mean, rgb_std, 1) self.head = nn.Sequential(*modules_head) self.body = nn.Sequential(*modules_body) self.tail = nn.Sequential(*modules_tail) def forward(self, x): x = self.sub_mean(x) x = self.head(x) res = self.body(x) res += x x = self.tail(res) x = self.add_mean(x) return x def load_state_dict(self, state_dict, strict=False): own_state = self.state_dict() for name, param in state_dict.items(): if name in own_state: if isinstance(param, nn.Parameter): param = param.data try: own_state[name].copy_(param) except Exception: if name.find("tail") >= 0: print("Replace pre-trained upsampler to new one...") else: raise RuntimeError( "While copying the parameter named {}, " "whose dimensions in the model are {} and " "whose dimensions in the checkpoint are {}.".format( name, own_state[name].size(), param.size() ) ) elif strict: if name.find("tail") == -1: raise KeyError('unexpected key "{}" in state_dict'.format(name)) if strict: missing = 
class ResidualDenseBlock_5C(nn.Module):
    """Dense block of five 3x3 convs.

    Each conv consumes the block input concatenated with all previous conv
    outputs (growth channels ``gc``); the final conv maps back to ``nf``
    channels and is added to the input with a 0.2 residual scale.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Small init (scale 0.1) stabilizes training of deep residual stacks.
        initialize_weights(
            [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1
        )

    def forward(self, x):
        # Grow the feature list densely: conv_k sees [x, f1, ..., f_{k-1}].
        feats = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        out = self.conv5(torch.cat(feats, 1))
        return out * 0.2 + x
1, bias=True) #### upsampling self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) if upscale == 4: self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): fea = self.conv_first(x) trunk = self.trunk_conv(self.RRDB_trunk(fea)) fea = fea + trunk if self.upscale == 2 or self.upscale == 3: fea = self.lrelu( self.upconv1( F.interpolate(fea, scale_factor=self.upscale, mode="nearest") ) ) if self.upscale == 4: fea = self.lrelu( self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest")) ) fea = self.lrelu( self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest")) ) out = self.conv_last(self.lrelu(self.HRconv(fea))) return out ================================================ FILE: codes/config/BSRGAN/archs/srresnet.py ================================================ import functools from utils.registry import ARCH_REGISTRY from .module_util import * @ARCH_REGISTRY.register() class MSRResNet(nn.Module): """modified SRResNet""" def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4): super(MSRResNet, self).__init__() self.upscale = upscale self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) basic_block = functools.partial(ResidualBlock_noBN, nf=nf) self.recon_trunk = make_layer(basic_block, nb) # upsampling if self.upscale == 2: self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(2) elif self.upscale == 3: self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(3) elif self.upscale == 4: self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(2) self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) # activation function self.lrelu = 
nn.LeakyReLU(negative_slope=0.1, inplace=True) # initialization initialize_weights( [self.conv_first, self.upconv1, self.HRconv, self.conv_last], 0.1 ) if self.upscale == 4: initialize_weights(self.upconv2, 0.1) def forward(self, x): fea = self.lrelu(self.conv_first(x)) out = self.recon_trunk(fea) if self.upscale == 4: out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) elif self.upscale == 3 or self.upscale == 2: out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) out = self.conv_last(self.lrelu(self.HRconv(out))) base = F.interpolate( x, scale_factor=self.upscale, mode="bilinear", align_corners=False ) out += base return out ================================================ FILE: codes/config/BSRGAN/archs/translator.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from utils.registry import ARCH_REGISTRY def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias ) class BasicBlock(nn.Sequential): def __init__( self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True), ): m = [ nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), stride=stride, bias=bias, ) ] if bn: m.append(nn.BatchNorm2d(out_channels)) if act is not None: m.append(act) super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): def __init__( self, conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(ResBlock, self).__init__() m = [] for i in range(2): m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: m.append(nn.BatchNorm2d(n_feat)) if i == 0: m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale def forward(self, x): res = self.body(x).mul(self.res_scale) res += x return res class 
Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? for _ in range(int(math.log(scale, 2))): m.append(conv(n_feat, 4 * n_feat, 3, bias)) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 3: m.append(conv(n_feat, 9 * n_feat, 3, bias)) m.append(nn.PixelShuffle(3)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 1: m.append(nn.Identity()) else: raise NotImplementedError super(Upsampler, self).__init__(*m) @ARCH_REGISTRY.register() class Translator(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, scale=4, conv=default_conv): super().__init__() self.scale = scale # define head module if scale >= 1: m_head = [conv(in_nc, nf, 3)] else: s = int(1 / scale) m_head = [nn.Conv2d(in_nc, nf, kernel_size=2 * s + 1, stride=s, padding=s)] # define body module m_body = [ ResBlock(conv, nf, 3, act=nn.ReLU(True), res_scale=1) for _ in range(nb) ] m_body.append(conv(nf, nf, 3)) # define tail module m_tail = [ Upsampler(conv, scale, nf, act=False) if scale > 1 else nn.Identity(), conv(nf, out_nc, 3), ] self.head = nn.Sequential(*m_head) self.body = nn.Sequential(*m_body) self.tail = nn.Sequential(*m_tail) def forward(self, x): x = self.head(x) f = self.body(x) x = f + x x = self.tail(x) return x ================================================ FILE: codes/config/BSRGAN/archs/vgg.py ================================================ import os from collections import OrderedDict import torch from torch import nn as nn from torchvision.models import vgg as vgg from utils.registry import ARCH_REGISTRY VGG_PRETRAIN_PATH = "checkpoints/pretrained_models/vgg19-dcbb9e9d.pth" NAMES = { "vgg11": [ "conv1_1", "relu1_1", "pool1", "conv2_1", "relu2_1", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", 
def insert_bn(names):
    """Insert a bn layer name after each conv layer name.

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers inserted, e.g.
        ``"conv1_1"`` is followed by ``"bn1_1"``.
    """
    out = []
    for layer_name in names:
        if "conv" in layer_name:
            # "convX_Y" -> keep it and append the matching "bnX_Y".
            out.extend([layer_name, "bn" + layer_name.replace("conv", "")])
        else:
            out.append(layer_name)
    return out
Default: True. range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. Default: False. requires_grad (bool): If true, the parameters of VGG network will be optimized. Default: False. remove_pooling (bool): If true, the max pooling operations in VGG net will be removed. Default: False. pooling_stride (int): The stride of max pooling operation. Default: 2. """ def __init__( self, layer_name_list, vgg_type="vgg19", use_input_norm=True, range_norm=False, requires_grad=False, remove_pooling=False, pooling_stride=2, ): super(VGGFeatureExtractor, self).__init__() self.layer_name_list = layer_name_list self.use_input_norm = use_input_norm self.range_norm = range_norm self.names = NAMES[vgg_type.replace("_bn", "")] if "bn" in vgg_type: self.names = insert_bn(self.names) # only borrow layers that will be used to avoid unused params max_idx = 0 for v in layer_name_list: idx = self.names.index(v) if idx > max_idx: max_idx = idx if os.path.exists(VGG_PRETRAIN_PATH): vgg_net = getattr(vgg, vgg_type)(pretrained=False) state_dict = torch.load( VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage ) vgg_net.load_state_dict(state_dict) else: vgg_net = getattr(vgg, vgg_type)(pretrained=True) features = vgg_net.features[: max_idx + 1] modified_net = OrderedDict() for k, v in zip(self.names, features): if "pool" in k: # if remove_pooling is true, pooling operation will be removed if remove_pooling: continue else: # in some cases, we may want to change the default stride modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) else: modified_net[k] = v self.vgg_net = nn.Sequential(modified_net) if not requires_grad: self.vgg_net.eval() for param in self.parameters(): param.requires_grad = False else: self.vgg_net.train() for param in self.parameters(): param.requires_grad = True if self.use_input_norm: # the mean is for image with range [0, 1] self.register_buffer( "mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) ) # the std is for image with 
range [0, 1] self.register_buffer( "std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) ) def forward(self, x): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). Returns: Tensor: Forward results. """ if self.range_norm: x = (x + 1) / 2 if self.use_input_norm: x = (x - self.mean) / self.std output = {} for key, layer in self.vgg_net._modules.items(): x = layer(x) if key in self.layer_name_list: output[key] = x.clone() return output ================================================ FILE: codes/config/BSRGAN/count_flops.py ================================================ import argparse import sys import torch from torchsummaryX import summary sys.path.append("../../") import utils.option as option from models import create_model parser = argparse.ArgumentParser() parser.add_argument( "--opt", type=str, default="options/setting1/test/test_setting1_x4.yml", help="Path to option YMAL file of Predictor.", ) args = parser.parse_args() opt = option.parse(args.opt, root_path=".", is_train=True) opt = option.dict_to_nonedict(opt) model = create_model(opt) test_tensor = torch.randn(1, 3, 270, 180).cuda() for name, net in model.networks.items(): summary(net.cuda(), x=test_tensor) print("Above are results for net {}".format(name)) input() ================================================ FILE: codes/config/BSRGAN/inference.py ================================================ import argparse import logging import math import os import os.path as osp import random import sys import cv2 from collections import defaultdict from glob import glob from tqdm import tqdm import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp from tensorboardX import SummaryWriter sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from data.data_sampler import DistIterSampler from metrics import IQA from models import create_model #### options parser = 
def create_model(opt, **kwarg):
    """Instantiate the model class named by ``opt["model"]``.

    The class is looked up in MODEL_REGISTRY (populated by the
    ``*_model.py`` modules imported above) and constructed with the full
    options dict plus any extra keyword arguments.

    Args:
        opt (dict): Options; ``opt["model"]`` selects the registered class.
        **kwarg: Forwarded verbatim to the model constructor.

    Returns:
        The constructed model instance.
    """
    model = opt["model"]
    m = MODEL_REGISTRY.get(model)(opt, **kwarg)
    logger.info("Model [{:s}] is created.".format(m.__class__.__name__))
    return m
DistributedDataParallel from archs import build_loss, build_network, build_scheduler from utils.registry import MODEL_REGISTRY logger = logging.getLogger("base") @MODEL_REGISTRY.register() class BaseModel: def __init__(self, opt): self.opt = opt if opt["dist"]: self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() else: self.rank = 0 # non dist training self.device = torch.device("cuda" if opt["gpu_ids"] is not None else "cpu") self.is_train = opt["is_train"] self.log_dict = OrderedDict() self.data_names = [] self.networks = {} self.optimizers = {} self.schedulers = {} def setup_train(self, train_opt): # define losses loss_opt = train_opt["losses"] self.losses = self.build_losses(loss_opt) # build optmizers optimizer_opts = train_opt["optimizers"] self.optimizers = self.build_optimizers(optimizer_opts) # set schedulers scheduler_opts = train_opt["schedulers"] self.schedulers = self.build_schedulers(scheduler_opts) # set to training state self.set_network_state(self.networks.keys(), "train") def feed_data(self, data): pass def optimize_parameters(self): pass def get_current_visuals(self): pass def get_current_losses(self): pass def print_network(self): pass def save(self, label): pass def load(self): pass def build_network(self, net_opt): net = build_network(net_opt) if isinstance(net, nn.Module): net = self.model_to_device(net) if net_opt.get("pretrain"): pretrain = net_opt.pop("pretrain") self.load_network(net, pretrain["path"], pretrain["strict_load"]) self.print_network(net) return net def build_losses(self, loss_opt): losses = {} defined_loss_names = list(loss_opt.keys()) assert set(defined_loss_names).issubset(set(self.loss_names)) for name in defined_loss_names: loss_conf = loss_opt.get(name) if loss_conf["weight"] > 0: self.loss_weights[name] = loss_conf.pop("weight") losses[name] = build_loss(loss_conf).to(self.device) return losses def build_optimizers(self, optim_opts): optimizers = {} if "default" in 
optim_opts.keys(): default_optim = optim_opts.pop("default") defined_optimizer_names = list(optim_opts.keys()) assert set(defined_optimizer_names).issubset(self.networks.keys()) for name in defined_optimizer_names: optim_opt = optim_opts[name] if optim_opt is None: optim_opt = default_optim.copy() params = [] for v in self.networks[name].parameters(): if v.requires_grad: params.append(v) optim_type = optim_opt.pop("type") optimizer = getattr(torch.optim, optim_type)(params=params, **optim_opt) optimizers[name] = optimizer return optimizers def build_schedulers(self, scheduler_opts): """Set up scheduler.""" schedulers = {} if "default" in scheduler_opts.keys(): default_opt = scheduler_opts.pop("default") for name in self.optimizers.keys(): scheduler_opt = scheduler_opts[name] if scheduler_opt is None: scheduler_opt = default_opt.copy() schedulers[name] = build_scheduler(self.optimizers[name], scheduler_opt) return schedulers def model_to_device(self, net): """Model to device. It also warps models with DistributedDataParallel or DataParallel. 
Args: net (nn.Module) """ net = net.to(self.device) if self.opt["dist"]: net = DistributedDataParallel(net, device_ids=[torch.cuda.current_device()]) else: net = DataParallel(net) return net def print_network(self, net): # Generator s, n = self.get_network_description(net) if isinstance(net, nn.DataParallel) or isinstance(net, DistributedDataParallel): net_struc_str = "{} - {}".format( net.__class__.__name__, net.module.__class__.__name__ ) else: net_struc_str = "{}".format(net.__class__.__name__) if self.rank <= 0: logger.info( "Network G structure: {}, with parameters: {:,d}".format( net_struc_str, n ) ) logger.info(s) def set_optimizer(self, names, operation): for name in names: getattr(self.optimizers[name], operation)() def set_requires_grad(self, names, requires_grad): for name in names: if isinstance(self.networks[name], nn.Module): for v in self.networks[name].parameters(): v.requires_grad = requires_grad def set_network_state(self, names, state): for name in names: if isinstance(self.networks[name], nn.Module): getattr(self.networks[name], state)() def clip_grad_norm(self, names, norm): for name in names: nn.utils.clip_grad_norm_(self.networks[name].parameters(), max_norm=norm) def _set_lr(self, lr_groups_l): """set learning rate for warmup, lr_groups_l: list for lr_groups. 
each for a optimizer""" for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): for param_group, lr in zip(optimizer.param_groups, lr_groups): param_group["lr"] = lr def _get_init_lr(self): # get the initial lr, which is set by the scheduler init_lr_groups_l = [] for optimizer in self.optimizers: init_lr_groups_l.append([v["initial_lr"] for v in optimizer.param_groups]) return init_lr_groups_l def update_learning_rate(self, cur_iter, warmup_iter=-1): for _, scheduler in self.schedulers.items(): scheduler.step() #### set up warm up learning rate if cur_iter < warmup_iter: # get initial lr for each group init_lr_g_l = self._get_init_lr() # modify warming-up learning rates warm_up_lr_l = [] for init_lr_g in init_lr_g_l: warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) # set learning rate self._set_lr(warm_up_lr_l) def get_current_learning_rate(self): # return self.schedulers[0].get_lr()[0] return list(self.optimizers.values())[0].param_groups[0]["lr"] def get_network_description(self, network): """Get the string and total parameters of the network""" if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module s = str(network) n = sum(map(lambda x: x.numel(), network.parameters())) return s, n def save_network(self, network, network_label, iter_label): save_filename = "{}_{}.pth".format(iter_label, network_label) save_path = os.path.join(self.opt["path"]["models"], save_filename) if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module state_dict = network.state_dict() for key, param in state_dict.items(): state_dict[key] = param.cpu() torch.save(state_dict, save_path) def save(self, iter_label): for name in self.optimizers.keys(): self.save_network(self.networks[name], name, iter_label) def load_network(self, network, load_path, strict=True): if load_path is not None: if isinstance(network, nn.DataParallel) or isinstance( 
    def reduce_loss_dict(self, loss_dict):
        """Reduce a loss dict into plain-float log values.

        In distributed training, it averages the losses among different GPUs.

        Args:
            loss_dict (OrderedDict): name -> loss tensor.

        Returns:
            OrderedDict: name -> python float (mean of the loss tensor).
        """
        with torch.no_grad():
            if self.opt["dist"]:
                keys = []
                losses = []
                for name, value in loss_dict.items():
                    keys.append(name)
                    losses.append(value)
                # Sum the stacked losses across ranks onto rank 0.
                losses = torch.stack(losses, 0)
                torch.distributed.reduce(losses, dst=0)
                if self.rank == 0:
                    # Only rank 0 holds the full sum; convert it to a mean.
                    losses /= self.world_size
                # NOTE(review): indentation reconstructed from a collapsed
                # extract; upstream BasicSR rebuilds loss_dict *outside* the
                # rank check (non-zero ranks then log unreduced values) —
                # confirm against the original file.
                loss_dict = {key: loss for key, loss in zip(keys, losses)}

            log_dict = OrderedDict()
            for name, value in loss_dict.items():
                # .item() assumes every logged value is a tensor.
                log_dict[name] = value.mean().item()

            return log_dict
            self.netD, self.losses["sr_adv"], self.hr, self.sr
        )
        loss_dict["sr_adv_g"] = sr_adv_g
        l_sr += self.loss_weights["sr_adv"] * sr_adv_g

        if self.losses.get("sr_percep"):
            # Perceptual (VGG-feature) loss; may additionally return a
            # style (Gram-matrix) term, or None when style_weight == 0.
            sr_percep, sr_style = self.losses["sr_percep"](self.hr, self.sr)
            loss_dict["sr_percep"] = sr_percep
            if sr_style is not None:
                loss_dict["sr_style"] = sr_style
                # NOTE(review): the style term is weighted with the
                # "sr_percep" weight -- confirm a dedicated style weight
                # is not intended.
                l_sr += self.loss_weights["sr_percep"] * sr_style
            l_sr += self.loss_weights["sr_percep"] * sr_percep

        # Generator step: only netSR parameters are updated here.
        self.set_optimizer(names=["netSR"], operation="zero_grad")
        l_sr.backward()
        self.set_optimizer(names=["netSR"], operation="step")

        if self.losses.get("sr_adv"):
            # Discriminator step: re-enable netD grads (disabled above for
            # the generator pass) and update netD alone.
            self.set_requires_grad(["netD"], True)
            sr_adv_d = self.calculate_rgan_loss_D(
                self.netD, self.losses["sr_adv"], self.hr, self.sr
            )
            loss_dict["sr_adv_d"] = sr_adv_d
            self.optimizers["netD"].zero_grad()
            sr_adv_d.backward()
            self.optimizers["netD"].step()

        self.log_dict = self.reduce_loss_dict(loss_dict)

    def calculate_rgan_loss_D(self, netD, criterion, real, fake):
        """Relativistic average GAN loss for the discriminator.

        `fake` is detached so no gradient flows into the generator; each
        branch compares a prediction against the mean of the opposite
        branch (detached), per the RaGAN formulation.
        """
        d_pred_fake = netD(fake.detach())
        d_pred_real = netD(real)
        loss_real = criterion(
            d_pred_real - d_pred_fake.detach().mean(), True, is_disc=False
        )
        loss_fake = criterion(
            d_pred_fake - d_pred_real.detach().mean(), False, is_disc=False
        )

        loss = (loss_real + loss_fake) / 2

        return loss

    def calculate_rgan_loss_G(self, netD, criterion, real, fake):
        """Relativistic average GAN loss for the generator.

        Real predictions are detached (only the generator receives
        gradients); target labels are swapped relative to the D loss.
        """
        d_pred_fake = netD(fake)
        d_pred_real = netD(real).detach()
        loss_real = criterion(d_pred_real - d_pred_fake.mean(), False, is_disc=False)
        loss_fake = criterion(d_pred_fake - d_pred_real.mean(), True, is_disc=False)

        loss = (loss_real + loss_fake) / 2

        return loss

    def test(self, data, crop_size=None):
        """Run gradient-free inference on data["src"].

        Stores the input as self.real_lr and the output as
        self.fake_real_hr; netSR is switched back to train mode afterwards.
        """
        self.real_lr = data["src"].to(self.device)
        self.netSR.eval()
        with torch.no_grad():
            if crop_size is None:
                self.fake_real_hr = self.netSR(self.real_lr)
            else:
                # Tile-based inference for inputs too large for one pass.
                self.fake_real_hr = self.crop_test(self.real_lr, crop_size)
        self.netSR.train()

    def crop_test(self, lr, crop_size):
        """Tiled SR inference: run netSR patch-wise and merge the results."""
        b, c, h, w = lr.shape
        scale = self.opt["scale"]

        # Patch grid anchored at the top-left corner.
        h_start = list(range(0, h-crop_size, crop_size))
        w_start = list(range(0, w-crop_size,
crop_size)) sr1 = torch.zeros(b, c, int(h*scale), int(w* scale), device=self.device) - 1 for hs in h_start: for ws in w_start: lr_patch = lr[:, :, hs: hs+crop_size, ws: ws+crop_size] sr_patch = self.netSR(lr_patch) sr1[:, :, int(hs*scale):int((hs+crop_size)*scale), int(ws*scale):int((ws+crop_size)*scale) ] = sr_patch h_end = list(range(h, crop_size, -crop_size)) w_end = list(range(w, crop_size, -crop_size)) sr2 = torch.zeros(b, c, int(h*scale), int(w* scale), device=self.device) - 1 for hd in h_end: for wd in w_end: lr_patch = lr[:, :, hd-crop_size:hd, wd-crop_size:wd] sr_patch = self.netSR(lr_patch) sr2[:, :, int((hd-crop_size)*scale):int(hd*scale), int((wd-crop_size)*scale):int(wd*scale) ] = sr_patch mask1 = ( (sr1 == -1).float() * 0 + (sr2 == -1).float() * 1 + ((sr1 > 0) * (sr2 > 0)).float() * 0.5 ) mask2 = ( (sr1 == -1).float() * 1 + (sr2 == -1).float() * 0 + ((sr1 > 0) * (sr2 > 0)).float() * 0.5 ) sr = mask1 * sr1 + mask2 * sr2 return sr def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict["lr"] = self.real_lr.detach()[0].float().cpu() out_dict["sr"] = self.fake_real_hr.detach()[0].float().cpu() return out_dict ================================================ FILE: codes/config/BSRGAN/options/test/2017Track2_2020Track1.yml ================================================ #### general settings name: 2017Track2_2020Track1 use_tb_logger: false model: SRModel scale: 4 gpu_ids: [6] metrics: [psnr, ssim, lpips, niqe, piqe, brisque] datasets: test1: name: 2017Track2 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb test2: name: 2020Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netSR: which_network: RRDBNet setting: in_nc: 3 out_nc: 3 nf: 64 nb: 23 gc: 32 upscale: 4 pretrain: path: 
../../../checkpoints/BSRGAN/BSRGAN.pth strict_load: true ================================================ FILE: codes/config/BSRGAN/options/test/2018Track2_2018Track4.yml ================================================ #### general settings name: 2018Track2_2018Track4 use_tb_logger: false model: SRModel scale: 4 gpu_ids: [6] metrics: [best_psnr, best_ssim, best_lpips, niqe, piqe, brisque] datasets: test1: name: 2018Track2 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb test2: name: 2018Track4 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netSR: which_network: RRDBNet setting: in_nc: 3 out_nc: 3 nf: 64 nb: 23 gc: 32 upscale: 4 pretrain: path: ../../../checkpoints/BSRGAN/BSRGAN.pth strict_load: true ================================================ FILE: codes/config/BSRGAN/options/test/2020Track2.yml ================================================ #### general settings name: 2020Track2 use_tb_logger: false model: SRModel scale: 4 gpu_ids: [0] metrics: [niqe, piqe, brisque] datasets: test1: name: 2020Track2 mode: SingleDataset data_type: lmdb dataroot: /home/lzx/SRDatasets/NTIRE2020/track2/test.lmdb #### network structures networks: netSR: which_network: RRDBNet setting: in_nc: 3 out_nc: 3 nf: 64 nb: 23 gc: 32 upscale: 4 pretrain: path: ../../../checkpoints/BSRGAN/BSRGAN.pth strict_load: true ================================================ FILE: codes/config/BSRGAN/test.py ================================================ import argparse import logging import os.path import sys import time from collections import OrderedDict, defaultdict import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp sys.path.append("../../") import utils as util import utils.option as option from 
data import create_dataloader, create_dataset from metrics import IQA from models import create_model from utils import bgr2ycbcr, imresize def parse_args(): parser = argparse.ArgumentParser(description="Train keypoints network") # general parser.add_argument( "--opt", help="experiment configure file name", required=True, type=str ) parser.add_argument( "--root_path", help="experiment configure file name", default="../../../", type=str, ) # distributed training parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str) parser.add_argument( "--world-size", default=1, type=int, help="number of nodes for distributed training", ) parser.add_argument( "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training", ) parser.add_argument( "--rank", default=0, type=int, help="node rank for distributed training" ) args = parser.parse_args() return args def main(): args = parse_args() opt = option.parse(args.opt, args.root_path, is_train=False) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) if args.dist_url == "env://" and args.world_size == -1: args.world_size = int(os.environ["WORLD_SIZE"]) ngpus_per_node = torch.cuda.device_count() args.world_size = ngpus_per_node * args.world_size opt["dist"] = args.world_size > 1 util.mkdirs( (path for key, path in opt["path"].items() if not key == "experiments_root") ) os.system("rm ./result") os.symlink(os.path.join(opt["path"]["results_root"], ".."), "./result") if opt["dist"]: mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args)) else: main_worker(0, 1, opt, args) def main_worker(gpu, ngpus_per_node, opt, args): if opt["dist"]: if args.dist_url == "env://" and args.rank == -1: rank = int(os.environ["RANK"]) rank = args.rank * ngpus_per_node + gpu print( f"Init process group: dist_url: {args.dist_url}, world_size: {args.world_size}, rank: {rank}" ) dist.init_process_group( backend="nccl", 
init_method=args.dist_url, world_size=args.world_size, rank=rank, ) torch.cuda.set_device(gpu) else: rank = 0 torch.backends.cudnn.benchmark = True util.setup_logger( "base", opt["path"]["log"], "test_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO, screen=True, tofile=True, ) measure = IQA(metrics=opt["metrics"], cuda=True) logger = logging.getLogger("base") logger.info(option.dict2str(opt)) # Create test dataset and dataloader test_datasets = [] test_loaders = [] for phase, dataset_opt in sorted(opt["datasets"].items()): test_set = create_dataset(dataset_opt) test_loader = create_dataloader(test_set, dataset_opt, opt["dist"]) if rank == 0: logger.info( "Number of test images in [{:s}]: {:d}".format( dataset_opt["name"], len(test_set) ) ) test_datasets.append(test_set) test_loaders.append(test_loader) # load pretrained model by default model = create_model(opt) for test_dataset, test_loader in zip(test_datasets, test_loaders): test_set_name = test_dataset.opt["name"] dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name) if rank == 0: logger.info("\nTesting [{:s}]...".format(test_set_name)) util.mkdir(dataset_dir) validate( model, test_dataset, test_loader, opt, measure, dataset_dir, test_set_name, logger, ) def validate( model, dataset, dist_loader, opt, measure, dataset_dir, test_set_name, logger ): test_results = {} test_results_y = {} for metric in opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() test_results_y[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 indices = list(range(rank, len(dataset), world_size)) for ( idx, test_data, ) in enumerate(dist_loader): idx = indices[idx] img_path = test_data["src_path"][0] img_name = img_path.split("/")[-1].split(".")[0] model.test(test_data) visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals["sr"]) # uint8 suffix = opt["suffix"] if suffix: 
save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png") else: save_img_path = os.path.join(dataset_dir, img_name + ".png") util.save_img(sr_img, save_img_path) message = "img:{:15s}; ".format(img_name) crop_border = opt["crop_border"] if opt["crop_border"] else opt["scale"] if crop_border == 0: cropped_sr_img = sr_img else: cropped_sr_img = sr_img[ crop_border:-crop_border, crop_border:-crop_border, : ] if "tgt" in test_data.keys(): gt_img = util.tensor2img(test_data["tgt"][0].double().cpu()) if crop_border == 0: cropped_gt_img = gt_img else: cropped_gt_img = gt_img[ crop_border:-crop_border, crop_border:-crop_border, : ] else: gt_img = None cropped_gt_img = None message += "Scores - " scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v message += "{}: {:.6f}; ".format(k, v) if sr_img.shape[2] == 3: # RGB image sr_img_y = bgr2ycbcr(sr_img, only_y=True) if crop_border == 0: cropped_sr_img_y = sr_img_y * 255 else: cropped_sr_img_y = ( sr_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255 ) if gt_img is not None: gt_img_y = bgr2ycbcr(gt_img, only_y=True) if crop_border == 0: cropped_gt_img_y = gt_img_y * 255 else: cropped_gt_img_y = ( gt_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255 ) else: gt_img_y = None cropped_gt_img_y = None message += "Y Scores - " scores = measure( res=cropped_sr_img_y, ref=cropped_gt_img_y, metrics=opt["metrics"] ) for k, v in scores.items(): test_results_y[k][idx] = v message += "{}: {:.6f}; ".format(k, v) logger.info(message) if opt["dist"]: for k, v in test_results.items(): dist.reduce(v, dst=0) dist.barrier() for k, v in test_results_y.items(): dist.reduce(v, dst=0) dist.barrier() # log avg_results = {} message = "Average Results for {}\n".format(test_set_name) if rank == 0: for k, v in test_results.items(): avg_results[k] = sum(v) / len(v) message += "{}: {:.6f}; ".format(k, avg_results[k]) logger.info(message) 
avg_results_y = {} message = "Average Results on Y channel for {}\n".format(test_set_name) if rank == 0: for k, v in test_results_y.items(): avg_results[k] = sum(v) / len(v) message += "{}: {:.6f}; ".format(k, avg_results[k]) logger.info(message) if __name__ == "__main__": main() ================================================ FILE: codes/config/BSRGAN/train.py ================================================ import argparse import logging import math import os import random import sys import time from collections import defaultdict import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp from tensorboardX import SummaryWriter from tqdm import tqdm sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from metrics import IQA from models import create_model def parse_args(): parser = argparse.ArgumentParser(description="Train keypoints network") # general parser.add_argument( "--opt", help="experiment configure file name", required=True, type=str ) parser.add_argument( "--root_path", help="experiment configure file name", default="../../../", type=str, ) # distributed training parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str) parser.add_argument( "--world-size", default=1, type=int, help="number of nodes for distributed training", ) parser.add_argument( "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training", ) parser.add_argument( "--rank", default=0, type=int, help="node rank for distributed training" ) args = parser.parse_args() return args def setup_dataloaer(opt, logger): if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: rank = 0 world_size = 1 for phase, dataset_opt in opt["datasets"].items(): if phase == "train": train_set = create_dataset(dataset_opt) train_loader = create_dataloader(train_set, dataset_opt, opt["dist"]) total_iters = 
opt["train"]["niter"] total_epochs = total_iters // (len(train_loader) - 1) + 1 if rank == 0: logger.info( "Number of train images: {:,d}, iters: {:,d}".format( len(train_set), len(train_loader) ) ) logger.info( "Total epochs needed: {:d} for iters {:,d}".format( total_epochs, opt["train"]["niter"] ) ) elif phase == "val": val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt["dist"]) if rank == 0: logger.info( "Number of val images in [{:s}]: {:d}".format( dataset_opt["name"], len(val_set) ) ) else: raise NotImplementedError("Phase [{:s}] is not recognized.".format(phase)) assert train_loader is not None assert val_loader is not None return train_set, train_loader, val_set, val_loader, total_iters, total_epochs def main(): args = parse_args() opt = option.parse(args.opt, args.root_path, is_train=True) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) if args.dist_url == "env://" and args.world_size == -1: args.world_size = int(os.environ["WORLD_SIZE"]) ngpus_per_node = torch.cuda.device_count() args.world_size = ngpus_per_node * args.world_size opt["dist"] = args.world_size > 1 if opt["train"].get("resume_state", None) is None: util.mkdir_and_rename( opt["path"]["experiments_root"] ) # rename experiment folder if exists util.mkdirs( (path for key, path in opt["path"].items() if not key == "experiments_root") ) os.system("rm ./log") os.symlink(os.path.join(opt["path"]["experiments_root"], ".."), "./log") if opt["dist"]: mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args)) else: main_worker(0, 1, opt, args) def main_worker(gpu, ngpus_per_node, opt, args): if opt["dist"]: if args.dist_url == "env://" and args.rank == -1: rank = int(os.environ["RANK"]) rank = args.rank * ngpus_per_node + gpu print( f"Init process group: dist_url: \ {args.dist_url}, world_size: {args.world_size}, rank: {rank}" ) dist.init_process_group( backend="nccl", 
init_method=args.dist_url, world_size=args.world_size, rank=rank, ) torch.cuda.set_device(gpu) else: rank = 0 seed = opt["train"]["manual_seed"] if seed is None: util.set_random_seed(rank) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True # setup tensorboard and val logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger = SummaryWriter(log_dir="log/{}/tb_logger/".format(opt["name"])) util.setup_logger( "val", opt["path"]["log"], "val_" + opt["name"], level=logging.INFO, screen=True, tofile=True, ) measure = IQA(metrics=opt["metrics"], cuda=True) # config loggers. Before it, the log will not work util.setup_logger( "base", opt["path"]["log"], "train_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO if rank == 0 else logging.ERROR, screen=True, tofile=True, ) logger = logging.getLogger("base") if rank == 0: logger.info(option.dict2str(opt)) # create dataset ( train_set, train_loader, val_set, val_loader, total_iters, total_epochs, ) = setup_dataloaer(opt, logger) # create model model = create_model(opt) # loading resume state if exists if opt["train"].get("resume_state", None): # distributed resuming: all load into default GPU device_id = gpu resume_state = torch.load( opt["train"]["resume_state"], map_location=lambda storage, loc: storage.cuda(device_id), ) logger.info( "Resuming training from epoch: {}, iter: {}.".format( resume_state["epoch"], resume_state["iter"] ) ) start_epoch = resume_state["epoch"] current_step = resume_state["iter"] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 logger.info( "Start training from epoch: {:d}, iter: {:d}".format(start_epoch, current_step) ) data_time, iter_time = time.time(), time.time() avg_data_time = avg_iter_time = 0 count = 0 for epoch in range(start_epoch, total_epochs + 1): for _, train_data in enumerate(train_loader): current_step += 1 count += 1 if current_step > total_iters: break 
data_time = time.time() - data_time avg_data_time = (avg_data_time * (count - 1) + data_time) / count model.feed_data(train_data) model.optimize_parameters(current_step) model.update_learning_rate( current_step, warmup_iter=opt["train"]["warmup_iter"] ) iter_time = time.time() - iter_time avg_iter_time = (avg_iter_time * (count - 1) + iter_time) / count # log if current_step % opt["logger"]["print_freq"] == 0: logs = model.get_current_log() message = ( f" " ) message += f'[time (data): {avg_iter_time:.3f} ({avg_data_time:.3f})] ' for k, v in logs.items(): message += "{:s}: {:.4e}; ".format(k, v) # tensorboard logger if opt["use_tb_logger"] and "debug" not in opt["name"]: if rank == 0: tb_logger.add_scalar(k, v, current_step) logger.info(message) # validation if current_step % opt["train"]["val_freq"] == 0: avg_results = validate( model, val_set, val_loader, opt, measure, epoch, current_step ) # tensorboard logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: for k, v in avg_results.items(): tb_logger.add_scalar(k, v, current_step) # save models and training states if current_step % opt["logger"]["save_checkpoint_freq"] == 0: if rank == 0: logger.info("Saving models and training states.") model.save(current_step) model.save_training_state(epoch, current_step) data_time = time.time() iter_time = time.time() if rank == 0: logger.info("Saving the final model.") model.save("latest") logger.info("End of training.") if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger.close() def validate(model, dataset, dist_loader, opt, measure, epoch, current_step): test_results = {} for metric in opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 if rank == 0: pbar = tqdm(total=len(dataset), leave=False, dynamic_ncols=True) indices = list(range(rank, len(dataset), world_size)) for ( idx, val_data, ) in 
enumerate(dist_loader): idx = indices[idx] LR_img = val_data["src"] lr_img = util.tensor2img(LR_img) # save LR image for reference model.test(val_data) visuals = model.get_current_visuals() # Save images for reference img_name = val_data["src_path"][0].split("/")[-1].split(".")[0] img_dir = os.path.join(opt["path"]["val_images"], img_name) util.mkdir(img_dir) save_lr_path = os.path.join(img_dir, "{:s}_LR.png".format(img_name)) util.save_img(lr_img, save_lr_path) sr_img = util.tensor2img(visuals["sr"]) # uint8 save_img_path = os.path.join( img_dir, "{:s}_{:d}.png".format(img_name, current_step) ) util.save_img(sr_img, save_img_path) if "fake_lr" in visuals.keys(): fake_lr_img = util.tensor2img(visuals["fake_lr"]) save_img_path = os.path.join( img_dir, f"fake_lr_{current_step:d}.png" ) util.save_img(fake_lr_img, save_img_path) # calculate scores crop_size = opt["scale"] cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] if "tgt" in val_data.keys(): gt_img = util.tensor2img(val_data["tgt"]) cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] else: cropped_gt_img = gt_img = None scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v if rank == 0: for _ in range(world_size): pbar.update(1) if rank == 0: pbar.close() # log avg_results = {} message = " 0`, the perceptual loss will be calculated and the loss will multiplied by the weight. Default: 1.0. style_weight (float): If `style_weight > 0`, the style loss will be calculated and the loss will multiplied by the weight. Default: 0. criterion (str): Criterion used for perceptual loss. Default: 'l1'. 
""" def __init__( self, layer_weights, vgg_type="vgg19", use_input_norm=True, range_norm=False, perceptual_weight=1.0, style_weight=0.0, criterion="l1", ): super(PerceptualLoss, self).__init__() self.perceptual_weight = perceptual_weight self.style_weight = style_weight self.layer_weights = layer_weights self.vgg = VGGFeatureExtractor( layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, use_input_norm=use_input_norm, range_norm=range_norm, ) self.criterion_type = criterion if self.criterion_type == "l1": self.criterion = torch.nn.L1Loss() elif self.criterion_type == "l2": self.criterion = torch.nn.L2loss() elif self.criterion_type == "fro": self.criterion = None else: raise NotImplementedError(f"{criterion} criterion has not been supported.") def forward(self, x, gt): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). gt (Tensor): Ground-truth tensor with shape (n, c, h, w). Returns: Tensor: Forward results. """ # extract vgg features x_features = self.vgg(x) gt_features = self.vgg(gt.detach()) # calculate perceptual loss if self.perceptual_weight > 0: percep_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": percep_loss += ( torch.norm(x_features[k] - gt_features[k], p="fro") * self.layer_weights[k] ) else: percep_loss += ( self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] ) percep_loss *= self.perceptual_weight else: percep_loss = None # calculate style loss if self.style_weight > 0: style_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": style_loss += ( torch.norm( self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p="fro", ) * self.layer_weights[k] ) else: style_loss += ( self.criterion( self._gram_mat(x_features[k]), self._gram_mat(gt_features[k]), ) * self.layer_weights[k] ) style_loss *= self.style_weight else: style_loss = None return percep_loss, style_loss def _gram_mat(self, x): """Calculate Gram matrix. 
Args: x (torch.Tensor): Tensor with shape of (n, c, h, w). Returns: torch.Tensor: Gram matrix. """ n, c, h, w = x.size() features = x.view(n, c, w * h) features_t = features.transpose(1, 2) gram = features.bmm(features_t) / (c * h * w) return gram @LOSS_REGISTRY.register() class CharbonnierLoss(nn.Module): """Charbonnier Loss (L1)""" def __init__(self, eps=1e-6): super(CharbonnierLoss, self).__init__() self.eps = eps def forward(self, x, y): diff = x - y loss = torch.mean(torch.sqrt(diff * diff + self.eps)) return loss class GradientPenaltyLoss(nn.Module): def __init__(self, device=torch.device("cpu")): super(GradientPenaltyLoss, self).__init__() self.register_buffer("grad_outputs", torch.Tensor()) self.grad_outputs = self.grad_outputs.to(device) def get_grad_outputs(self, input): if self.grad_outputs.size() != input.size(): self.grad_outputs.resize_(input.size()).fill_(1.0) return self.grad_outputs def forward(self, interp, interp_crit): grad_outputs = self.get_grad_outputs(interp_crit) grad_interp = torch.autograd.grad( outputs=interp_crit, inputs=interp, grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True, )[0] grad_interp = grad_interp.view(grad_interp.size(0), -1) grad_interp_norm = grad_interp.norm(2, dim=1) loss = ((grad_interp_norm - 1) ** 2).mean() return loss ================================================ FILE: codes/config/Bicubic/archs/lr_scheduler.py ================================================ import math from collections import Counter, defaultdict import torch from torch.optim.lr_scheduler import _LRScheduler from utils.registry import LR_SCHEDULER_REGISTRY @LR_SCHEDULER_REGISTRY.register() class LinearDecayLR(_LRScheduler): def __init__( self, optimizer, decay_prop, total_steps, last_epoch=-1, ): self.decay_prop = decay_prop self.total_steps = total_steps super().__init__(optimizer, last_epoch) def get_lr(self): return [ group["initial_lr"] * (1 - (self.last_epoch + 1) * self.decay_prop / self.total_steps) for 
            group in self.optimizer.param_groups
        ]


@LR_SCHEDULER_REGISTRY.register()
class MultiStepRestartLR(_LRScheduler):
    """MultiStep learning-rate decay with optional warm restarts.

    At every step listed in `restarts` the LR is reset to
    initial_lr * restart_weight (and, if `clear_state`, the optimizer's
    internal state is wiped). At every step in `milestones` the LR is
    multiplied by `gamma` (once per occurrence in the Counter).
    """

    def __init__(
        self,
        optimizer,
        milestones,
        restarts=None,
        weights=None,
        gamma=0.1,
        clear_state=False,
        last_epoch=-1,
    ):
        # Counter so a milestone listed k times decays by gamma**k.
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.clear_state = clear_state
        self.restarts = restarts if restarts else [0]
        self.restart_weights = weights if weights else [1]
        assert len(self.restarts) == len(
            self.restart_weights
        ), "restarts and their weights do not match."
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # Restart step: reset to initial_lr scaled by this restart's weight.
        if self.last_epoch in self.restarts:
            if self.clear_state:
                self.optimizer.state = defaultdict(dict)
            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
            return [
                group["initial_lr"] * weight for group in self.optimizer.param_groups
            ]
        # Not a milestone: keep the current LR unchanged.
        if self.last_epoch not in self.milestones:
            return [group["lr"] for group in self.optimizer.param_groups]
        # Milestone: decay by gamma (to the milestone's multiplicity).
        return [
            group["lr"] * self.gamma ** self.milestones[self.last_epoch]
            for group in self.optimizer.param_groups
        ]


@LR_SCHEDULER_REGISTRY.register()
class CosineAnnealingRestartLR(_LRScheduler):
    """Cosine-annealing schedule with warm restarts.

    `T_period[i]` is the cosine period after the i-th restart; `restarts`
    lists the steps at which the LR resets to initial_lr * weight.
    """

    def __init__(
        self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1
    ):
        self.T_period = T_period
        self.T_max = self.T_period[0]  # current T period
        self.eta_min = eta_min
        self.restarts = restarts if restarts else [0]
        self.restart_weights = weights if weights else [1]
        self.last_restart = 0
        assert len(self.restarts) == len(
            self.restart_weights
        ), "restarts and their weights do not match."
super().__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch == 0: return self.base_lrs elif self.last_epoch in self.restarts: self.last_restart = self.last_epoch self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [ group["initial_lr"] * weight for group in self.optimizer.param_groups ] elif (self.last_epoch - self.last_restart - 1 - self.T_max) % ( 2 * self.T_max ) == 0: return [ group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) ] return [ (1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / ( 1 + math.cos( math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max ) ) * (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups ] ================================================ FILE: codes/config/Bicubic/archs/module_util.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init def initialize_weights(net_l, scale=1): if not isinstance(net_l, list): net_l = [net_l] for net in net_l: for m in net.modules(): if isinstance(m, nn.Conv2d): init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale # for residual block if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.Linear): init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): init.constant_(m.weight, 1) init.constant_(m.bias.data, 0.0) def make_layer(block, n_layers): layers = [] for _ in range(n_layers): layers.append(block()) return nn.Sequential(*layers) class ResidualBlock_noBN(nn.Module): """Residual block w/o BN ---Conv-ReLU-Conv-+- |________________| """ def __init__(self, nf=64): super(ResidualBlock_noBN, self).__init__() 
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) # initialization initialize_weights([self.conv1, self.conv2], 0.1) def forward(self, x): identity = x out = F.relu(self.conv1(x), inplace=True) out = self.conv2(out) return identity + out def flow_warp(x, flow, interp_mode="bilinear", padding_mode="zeros"): """Warp an image or feature map with optical flow Args: x (Tensor): size (N, C, H, W) flow (Tensor): size (N, H, W, 2), normal value interp_mode (str): 'nearest' or 'bilinear' padding_mode (str): 'zeros' or 'border' or 'reflection' Returns: Tensor: warped image or feature map """ assert x.size()[-2:] == flow.size()[1:3] B, C, H, W = x.size() # mesh grid grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 grid.requires_grad = False grid = grid.type_as(x) vgrid = grid + flow # scale grid to [-1,1] vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) return output ================================================ FILE: codes/config/Bicubic/archs/rcan.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from utils.registry import ARCH_REGISTRY def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias ) class MeanShift(nn.Conv2d): def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): super(MeanShift, self).__init__(3, 3, kernel_size=1) std = torch.Tensor(rgb_std) self.weight.data = torch.eye(3).view(3, 3, 1, 1) self.weight.data.div_(std.view(3, 1, 1, 1)) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 
self.bias.data.div_(std) self.requires_grad = False class BasicBlock(nn.Sequential): def __init__( self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True), ): m = [ nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), stride=stride, bias=bias, ) ] if bn: m.append(nn.BatchNorm2d(out_channels)) if act is not None: m.append(act) super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): def __init__( self, conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(ResBlock, self).__init__() m = [] for i in range(2): m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: m.append(nn.BatchNorm2d(n_feat)) if i == 0: m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale def forward(self, x): res = self.body(x).mul(self.res_scale) res += x return res class Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
for _ in range(int(math.log(scale, 2))): m.append(conv(n_feat, 4 * n_feat, 3, bias)) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 3: m.append(conv(n_feat, 9 * n_feat, 3, bias)) m.append(nn.PixelShuffle(3)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) else: raise NotImplementedError super(Upsampler, self).__init__(*m) def make_model(args, parent=False): return RCAN(args) ## Channel Attention (CA) Layer class CALayer(nn.Module): def __init__(self, channel, reduction=16): super(CALayer, self).__init__() # global average pooling: feature --> point self.avg_pool = nn.AdaptiveAvgPool2d(1) # feature channel downscale and upscale --> channel weight self.conv_du = nn.Sequential( nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), nn.ReLU(inplace=True), nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), nn.Sigmoid(), ) def forward(self, x): y = self.avg_pool(x) y = self.conv_du(y) return x * y ## Residual Channel Attention Block (RCAB) class RCAB(nn.Module): def __init__( self, conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(RCAB, self).__init__() modules_body = [] for i in range(2): modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: modules_body.append(nn.BatchNorm2d(n_feat)) if i == 0: modules_body.append(act) modules_body.append(CALayer(n_feat, reduction)) self.body = nn.Sequential(*modules_body) self.res_scale = res_scale def forward(self, x): res = self.body(x) # res = self.body(x).mul(self.res_scale) res += x return res ## Residual Group (RG) class ResidualGroup(nn.Module): def __init__( self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks ): super(ResidualGroup, self).__init__() modules_body = [] modules_body = [ RCAB( conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ) for _ in range(n_resblocks) ] 
modules_body.append(conv(n_feat, n_feat, kernel_size)) self.body = nn.Sequential(*modules_body) def forward(self, x): res = self.body(x) res += x return res ## Residual Channel Attention Network (RCAN) @ARCH_REGISTRY.register() class RCAN(nn.Module): def __init__(self, ng, nb, nf, reduction=16, upscale=4, conv=default_conv): super(RCAN, self).__init__() n_resgroups = ng n_resblocks = nb n_feats = nf kernel_size = 3 reduction = reduction scale = upscale act = nn.ReLU(True) # RGB mean for DIV2K rgb_mean = (0.4488, 0.4371, 0.4040) rgb_std = (1.0, 1.0, 1.0) self.sub_mean = MeanShift(1.0, rgb_mean, rgb_std, -1) # define head module modules_head = [conv(3, n_feats, kernel_size)] # define body module modules_body = [ ResidualGroup( conv, n_feats, kernel_size, reduction, act=act, res_scale=1.0, n_resblocks=nb, ) for _ in range(ng) ] modules_body.append(conv(n_feats, n_feats, kernel_size)) # define tail module modules_tail = [ Upsampler(conv, scale, n_feats, act=False), conv(n_feats, 3, kernel_size), ] self.add_mean = MeanShift(1.0, rgb_mean, rgb_std, 1) self.head = nn.Sequential(*modules_head) self.body = nn.Sequential(*modules_body) self.tail = nn.Sequential(*modules_tail) def forward(self, x): x = self.sub_mean(x) x = self.head(x) res = self.body(x) res += x x = self.tail(res) x = self.add_mean(x) return x def load_state_dict(self, state_dict, strict=False): own_state = self.state_dict() for name, param in state_dict.items(): if name in own_state: if isinstance(param, nn.Parameter): param = param.data try: own_state[name].copy_(param) except Exception: if name.find("tail") >= 0: print("Replace pre-trained upsampler to new one...") else: raise RuntimeError( "While copying the parameter named {}, " "whose dimensions in the model are {} and " "whose dimensions in the checkpoint are {}.".format( name, own_state[name].size(), param.size() ) ) elif strict: if name.find("tail") == -1: raise KeyError('unexpected key "{}" in state_dict'.format(name)) if strict: missing = 
set(own_state.keys()) - set(state_dict.keys()) if len(missing) > 0: raise KeyError('missing keys in state_dict: "{}"'.format(missing)) ================================================ FILE: codes/config/Bicubic/archs/rrdb.py ================================================ import functools from utils.registry import ARCH_REGISTRY from .module_util import * class ResidualDenseBlock_5C(nn.Module): def __init__(self, nf=64, gc=32, bias=True): super(ResidualDenseBlock_5C, self).__init__() # gc: growth channel, i.e. intermediate channels self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) # initialization initialize_weights( [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1 ) def forward(self, x): x1 = self.lrelu(self.conv1(x)) x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) return x5 * 0.2 + x class RRDB(nn.Module): """Residual in Residual Dense Block""" def __init__(self, nf, gc=32): super(RRDB, self).__init__() self.RDB1 = ResidualDenseBlock_5C(nf, gc) self.RDB2 = ResidualDenseBlock_5C(nf, gc) self.RDB3 = ResidualDenseBlock_5C(nf, gc) def forward(self, x): out = self.RDB1(x) out = self.RDB2(out) out = self.RDB3(out) return out * 0.2 + x @ARCH_REGISTRY.register() class RRDBNet(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4): super(RRDBNet, self).__init__() self.upscale = upscale RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.RRDB_trunk = make_layer(RRDB_block_f, nb) self.trunk_conv = nn.Conv2d(nf, nf, 3, 
1, 1, bias=True) #### upsampling self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) if upscale == 4: self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): fea = self.conv_first(x) trunk = self.trunk_conv(self.RRDB_trunk(fea)) fea = fea + trunk if self.upscale == 2 or self.upscale == 3: fea = self.lrelu( self.upconv1( F.interpolate(fea, scale_factor=self.upscale, mode="nearest") ) ) if self.upscale == 4: fea = self.lrelu( self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest")) ) fea = self.lrelu( self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest")) ) out = self.conv_last(self.lrelu(self.HRconv(fea))) return out ================================================ FILE: codes/config/Bicubic/archs/srresnet.py ================================================ import functools from utils.registry import ARCH_REGISTRY from .module_util import * @ARCH_REGISTRY.register() class MSRResNet(nn.Module): """modified SRResNet""" def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4): super(MSRResNet, self).__init__() self.upscale = upscale self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) basic_block = functools.partial(ResidualBlock_noBN, nf=nf) self.recon_trunk = make_layer(basic_block, nb) # upsampling if self.upscale == 2: self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(2) elif self.upscale == 3: self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(3) elif self.upscale == 4: self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(2) self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) # activation function self.lrelu 
= nn.LeakyReLU(negative_slope=0.1, inplace=True) # initialization initialize_weights( [self.conv_first, self.upconv1, self.HRconv, self.conv_last], 0.1 ) if self.upscale == 4: initialize_weights(self.upconv2, 0.1) def forward(self, x): fea = self.lrelu(self.conv_first(x)) out = self.recon_trunk(fea) if self.upscale == 4: out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) elif self.upscale == 3 or self.upscale == 2: out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) out = self.conv_last(self.lrelu(self.HRconv(out))) base = F.interpolate( x, scale_factor=self.upscale, mode="bilinear", align_corners=False ) out += base return out ================================================ FILE: codes/config/Bicubic/archs/vgg.py ================================================ import os from collections import OrderedDict import torch from torch import nn as nn from torchvision.models import vgg as vgg from utils.registry import ARCH_REGISTRY VGG_PRETRAIN_PATH = "checkpoints/pretrained_models/vgg19-dcbb9e9d.pth" NAMES = { "vgg11": [ "conv1_1", "relu1_1", "pool1", "conv2_1", "relu2_1", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "pool5", ], "vgg13": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "pool5", ], "vgg16": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "pool5", ], "vgg19": [ "conv1_1", "relu1_1", "conv1_2", 
"relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "conv3_4", "relu3_4", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "conv4_4", "relu4_4", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "conv5_4", "relu5_4", "pool5", ], } def insert_bn(names): """Insert bn layer after each conv. Args: names (list): The list of layer names. Returns: list: The list of layer names with bn layers. """ names_bn = [] for name in names: names_bn.append(name) if "conv" in name: position = name.replace("conv", "") names_bn.append("bn" + position) return names_bn @ARCH_REGISTRY.register() class VGGFeatureExtractor(nn.Module): """VGG network for feature extraction. In this implementation, we allow users to choose whether use normalization in the input feature and the type of vgg network. Note that the pretrained path must fit the vgg type. Args: layer_name_list (list[str]): Forward function returns the corresponding features according to the layer_name_list. Example: {'relu1_1', 'relu2_1', 'relu3_1'}. vgg_type (str): Set the type of vgg network. Default: 'vgg19'. use_input_norm (bool): If True, normalize the input image. Importantly, the input feature must in the range [0, 1]. Default: True. range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. Default: False. requires_grad (bool): If true, the parameters of VGG network will be optimized. Default: False. remove_pooling (bool): If true, the max pooling operations in VGG net will be removed. Default: False. pooling_stride (int): The stride of max pooling operation. Default: 2. 
""" def __init__( self, layer_name_list, vgg_type="vgg19", use_input_norm=True, range_norm=False, requires_grad=False, remove_pooling=False, pooling_stride=2, ): super(VGGFeatureExtractor, self).__init__() self.layer_name_list = layer_name_list self.use_input_norm = use_input_norm self.range_norm = range_norm self.names = NAMES[vgg_type.replace("_bn", "")] if "bn" in vgg_type: self.names = insert_bn(self.names) # only borrow layers that will be used to avoid unused params max_idx = 0 for v in layer_name_list: idx = self.names.index(v) if idx > max_idx: max_idx = idx if os.path.exists(VGG_PRETRAIN_PATH): vgg_net = getattr(vgg, vgg_type)(pretrained=False) state_dict = torch.load( VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage ) vgg_net.load_state_dict(state_dict) else: vgg_net = getattr(vgg, vgg_type)(pretrained=True) features = vgg_net.features[: max_idx + 1] modified_net = OrderedDict() for k, v in zip(self.names, features): if "pool" in k: # if remove_pooling is true, pooling operation will be removed if remove_pooling: continue else: # in some cases, we may want to change the default stride modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) else: modified_net[k] = v self.vgg_net = nn.Sequential(modified_net) if not requires_grad: self.vgg_net.eval() for param in self.parameters(): param.requires_grad = False else: self.vgg_net.train() for param in self.parameters(): param.requires_grad = True if self.use_input_norm: # the mean is for image with range [0, 1] self.register_buffer( "mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) ) # the std is for image with range [0, 1] self.register_buffer( "std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) ) def forward(self, x): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). Returns: Tensor: Forward results. 
""" if self.range_norm: x = (x + 1) / 2 if self.use_input_norm: x = (x - self.mean) / self.std output = {} for key, layer in self.vgg_net._modules.items(): x = layer(x) if key in self.layer_name_list: output[key] = x.clone() return output ================================================ FILE: codes/config/Bicubic/count_flops.py ================================================ import argparse import sys import torch from torchsummaryX import summary sys.path.append("../../") import utils.option as option from models import create_model parser = argparse.ArgumentParser() parser.add_argument( "--opt", type=str, default="options/setting1/test/test_setting1_x4.yml", help="Path to option YMAL file of Predictor.", ) args = parser.parse_args() opt = option.parse(args.opt, root_path=".", is_train=True) opt = option.dict_to_nonedict(opt) model = create_model(opt) test_tensor = torch.randn(1, 3, 270, 180).cuda() for name, net in model.networks.items(): summary(net.cuda(), x=test_tensor) print("Above are results for net {}".format(name)) input() ================================================ FILE: codes/config/Bicubic/inference.py ================================================ import argparse import logging import math import os import os.path as osp import random import sys import cv2 from collections import defaultdict from glob import glob from tqdm import tqdm import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp from tensorboardX import SummaryWriter sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from data.data_sampler import DistIterSampler from metrics import IQA from models import create_model #### options parser = argparse.ArgumentParser() parser.add_argument( "-opt", type=str, default="options/test/2020Track2.yml", help="Path to options YMAL file.", ) parser.add_argument("-input_dir", type=str, default="../../../data_samples/LR") 
parser.add_argument("-output_dir", type=str, default="../../../data_samples/BSRGAN") args = parser.parse_args() opt = option.parse(args.opt, is_train=False) opt = option.dict_to_nonedict(opt) model = create_model(opt) if not osp.exists(args.output_dir): os.makedirs(args.output_dir) test_files = glob(osp.join(args.input_dir, "*")) for inx, path in tqdm(enumerate(test_files)): name = path.split("/")[-1].split(".")[0] img = cv2.imread(path)[:, :, [2, 1, 0]] img = img.transpose(2, 0, 1)[None] / 255 img_t = torch.as_tensor(np.ascontiguousarray(img)).float() model.test({"src": img_t}, crop_size=512) outdict = model.get_current_visuals() sr = outdict["sr"] sr_im = util.tensor2img(sr) save_path = osp.join(args.output_dir, "{}_x{}.png".format(name, opt["scale"])) cv2.imwrite(save_path, sr_im) ================================================ FILE: codes/config/Bicubic/models/__init__.py ================================================ import importlib import logging import os import os.path as osp from utils.registry import MODEL_REGISTRY logger = logging.getLogger("base") model_folder = osp.dirname(__file__) model_names = [ osp.splitext(osp.basename(v))[0] for v in os.listdir(model_folder) if v.endswith("_model.py") ] _model_modules = [ importlib.import_module(f"models.{file_name}") for file_name in model_names ] def create_model(opt, **kwarg): model = opt["model"] m = MODEL_REGISTRY.get(model)(opt, **kwarg) logger.info("Model [{:s}] is created.".format(m.__class__.__name__)) return m ================================================ FILE: codes/config/Bicubic/models/base_model.py ================================================ import logging import os from collections import OrderedDict import torch import torch.nn as nn from torch.nn.parallel import DataParallel, DistributedDataParallel from archs import build_loss, build_network, build_scheduler from utils.registry import MODEL_REGISTRY logger = logging.getLogger("base") @MODEL_REGISTRY.register() class BaseModel: def 
__init__(self, opt): self.opt = opt if opt["dist"]: self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() else: self.rank = 0 # non dist training self.device = torch.device("cuda" if opt["gpu_ids"] is not None else "cpu") self.is_train = opt["is_train"] self.log_dict = OrderedDict() self.data_names = [] self.networks = {} self.optimizers = {} self.schedulers = {} def setup_train(self, train_opt): # define losses loss_opt = train_opt["losses"] self.losses = self.build_losses(loss_opt) # build optmizers optimizer_opts = train_opt["optimizers"] self.optimizers = self.build_optimizers(optimizer_opts) # set schedulers scheduler_opts = train_opt["schedulers"] self.schedulers = self.build_schedulers(scheduler_opts) # set to training state self.set_network_state(self.networks.keys(), "train") def feed_data(self, data): pass def optimize_parameters(self): pass def get_current_visuals(self): pass def get_current_losses(self): pass def print_network(self): pass def save(self, label): pass def load(self): pass def build_network(self, net_opt): net = build_network(net_opt) if isinstance(net, nn.Module): net = self.model_to_device(net) if net_opt.get("pretrain"): pretrain = net_opt.pop("pretrain") self.load_network(net, pretrain["path"], pretrain["strict_load"]) self.print_network(net) return net def build_losses(self, loss_opt): losses = {} defined_loss_names = list(loss_opt.keys()) assert set(defined_loss_names).issubset(set(self.loss_names)) for name in defined_loss_names: loss_conf = loss_opt.get(name) if loss_conf["weight"] > 0: self.loss_weights[name] = loss_conf.pop("weight") losses[name] = build_loss(loss_conf).to(self.device) return losses def build_optimizers(self, optim_opts): optimizers = {} if "default" in optim_opts.keys(): default_optim = optim_opts.pop("default") defined_optimizer_names = list(optim_opts.keys()) assert set(defined_optimizer_names).issubset(self.networks.keys()) for name in defined_optimizer_names: optim_opt 
= optim_opts[name] if optim_opt is None: optim_opt = default_optim.copy() params = [] for v in self.networks[name].parameters(): if v.requires_grad: params.append(v) optim_type = optim_opt.pop("type") optimizer = getattr(torch.optim, optim_type)(params=params, **optim_opt) optimizers[name] = optimizer return optimizers def build_schedulers(self, scheduler_opts): """Set up scheduler.""" schedulers = {} if "default" in scheduler_opts.keys(): default_opt = scheduler_opts.pop("default") for name in self.optimizers.keys(): scheduler_opt = scheduler_opts[name] if scheduler_opt is None: scheduler_opt = default_opt.copy() schedulers[name] = build_scheduler(self.optimizers[name], scheduler_opt) return schedulers def model_to_device(self, net): """Model to device. It also warps models with DistributedDataParallel or DataParallel. Args: net (nn.Module) """ net = net.to(self.device) if self.opt["dist"]: net = DistributedDataParallel(net, device_ids=[torch.cuda.current_device()]) else: net = DataParallel(net) return net def print_network(self, net): # Generator s, n = self.get_network_description(net) if isinstance(net, nn.DataParallel) or isinstance(net, DistributedDataParallel): net_struc_str = "{} - {}".format( net.__class__.__name__, net.module.__class__.__name__ ) else: net_struc_str = "{}".format(net.__class__.__name__) if self.rank <= 0: logger.info( "Network G structure: {}, with parameters: {:,d}".format( net_struc_str, n ) ) logger.info(s) def set_optimizer(self, names, operation): for name in names: getattr(self.optimizers[name], operation)() def set_requires_grad(self, names, requires_grad): for name in names: if isinstance(self.networks[name], nn.Module): for v in self.networks[name].parameters(): v.requires_grad = requires_grad def set_network_state(self, names, state): for name in names: if isinstance(self.networks[name], nn.Module): getattr(self.networks[name], state)() def clip_grad_norm(self, names, norm): for name in names: 
nn.utils.clip_grad_norm_(self.networks[name].parameters(), max_norm=norm) def _set_lr(self, lr_groups_l): """set learning rate for warmup, lr_groups_l: list for lr_groups. each for a optimizer""" for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): for param_group, lr in zip(optimizer.param_groups, lr_groups): param_group["lr"] = lr def _get_init_lr(self): # get the initial lr, which is set by the scheduler init_lr_groups_l = [] for optimizer in self.optimizers: init_lr_groups_l.append([v["initial_lr"] for v in optimizer.param_groups]) return init_lr_groups_l def update_learning_rate(self, cur_iter, warmup_iter=-1): for _, scheduler in self.schedulers.items(): scheduler.step() #### set up warm up learning rate if cur_iter < warmup_iter: # get initial lr for each group init_lr_g_l = self._get_init_lr() # modify warming-up learning rates warm_up_lr_l = [] for init_lr_g in init_lr_g_l: warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) # set learning rate self._set_lr(warm_up_lr_l) def get_current_learning_rate(self): # return self.schedulers[0].get_lr()[0] return list(self.optimizers.values())[0].param_groups[0]["lr"] def get_network_description(self, network): """Get the string and total parameters of the network""" if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module s = str(network) n = sum(map(lambda x: x.numel(), network.parameters())) return s, n def save_network(self, network, network_label, iter_label): save_filename = "{}_{}.pth".format(iter_label, network_label) save_path = os.path.join(self.opt["path"]["models"], save_filename) if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module state_dict = network.state_dict() for key, param in state_dict.items(): state_dict[key] = param.cpu() torch.save(state_dict, save_path) def save(self, iter_label): for name in self.optimizers.keys(): 
self.save_network(self.networks[name], name, iter_label) def load_network(self, network, load_path, strict=True): if load_path is not None: if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module load_net = torch.load(load_path) load_net_clean = OrderedDict() # remove unnecessary 'module.' for k, v in load_net.items(): if k.startswith("module."): load_net_clean[k[7:]] = v else: load_net_clean[k] = v network.load_state_dict(load_net_clean, strict=strict) def save_training_state(self, epoch, iter_step): """Saves training state during training, which will be used for resuming""" state = {"epoch": epoch, "iter": iter_step, "schedulers": {}, "optimizers": {}} for k, s in self.schedulers.items(): state["schedulers"][k] = s.state_dict() for k, o in self.optimizers.items(): state["optimizers"][k] = o.state_dict() save_filename = "{}.state".format(iter_step) save_path = os.path.join(self.opt["path"]["training_state"], save_filename) torch.save(state, save_path) def resume_training(self, resume_state): """Resume the optimizers and schedulers for training""" resume_optimizers = resume_state["optimizers"] resume_schedulers = resume_state["schedulers"] assert len(resume_optimizers) == len( self.optimizers ), "Wrong lengths of optimizers" assert len(resume_schedulers) == len( self.schedulers ), "Wrong lengths of schedulers" for name, o in resume_optimizers.items(): self.optimizers[name].load_state_dict(o) for name, s in resume_schedulers.items(): self.schedulers[name].load_state_dict(s) def reduce_loss_dict(self, loss_dict): """reduce loss dict. In distributed training, it averages the losses among different GPUs . Args: loss_dict (OrderedDict): Loss dict. 
""" with torch.no_grad(): if self.opt["dist"]: keys = [] losses = [] for name, value in loss_dict.items(): keys.append(name) losses.append(value) losses = torch.stack(losses, 0) torch.distributed.reduce(losses, dst=0) if self.rank == 0: losses /= self.world_size loss_dict = {key: loss for key, loss in zip(keys, losses)} log_dict = OrderedDict() for name, value in loss_dict.items(): log_dict[name] = value.mean().item() return log_dict def get_current_log(self): return self.log_dict ================================================ FILE: codes/config/Bicubic/models/sr_model.py ================================================ import logging from collections import OrderedDict import torch import torch.nn as nn from utils.registry import MODEL_REGISTRY from .base_model import BaseModel logger = logging.getLogger("base") @MODEL_REGISTRY.register() class SRModel(BaseModel): def __init__(self, opt): super().__init__(opt) self.data_names = ["lr", "hr"] self.network_names = ["netSR"] self.networks = {} self.loss_names = ["sr_adv", "sr_pix", "sr_percep"] self.loss_weights = {} self.losses = {} self.optimizers = {} # define networks and load pretrained models nets_opt = opt["networks"] defined_network_names = list(nets_opt.keys()) assert set(defined_network_names).issubset(set(self.network_names)) for name in defined_network_names: setattr(self, name, self.build_network(nets_opt[name])) self.networks[name] = getattr(self, name) if self.is_train: # setup loss, optimizers, schedulers self.setup_train(opt["train"]) def feed_data(self, data): self.lr = data["src"].to(self.device) self.hr = data["tgt"].to(self.device) def forward(self): self.sr = self.netSR(self.lr) def optimize_parameters(self, step): self.forward() loss_dict = OrderedDict() l_sr = 0 sr_pix = self.losses["sr_pix"](self.hr, self.sr) loss_dict["sr_pix"] = sr_pix l_sr += self.loss_weights["sr_pix"] * sr_pix if self.losses.get("sr_adv"): self.set_requires_grad(["netD"], False) sr_adv_g = self.calculate_rgan_loss_G( 
self.netD, self.losses["sr_adv"], self.hr, self.sr ) loss_dict["sr_adv_g"] = sr_adv_g l_sr += self.loss_weights["sr_adv"] * sr_adv_g if self.losses.get("sr_percep"): sr_percep, sr_style = self.losses["sr_percep"](self.hr, self.sr) loss_dict["sr_percep"] = sr_percep if sr_style is not None: loss_dict["sr_style"] = sr_style l_sr += self.loss_weights["sr_percep"] * sr_style l_sr += self.loss_weights["sr_percep"] * sr_percep self.set_optimizer(names=["netSR"], operation="zero_grad") l_sr.backward() self.set_optimizer(names=["netSR"], operation="step") if self.losses.get("sr_adv"): self.set_requires_grad(["netD"], True) sr_adv_d = self.calculate_rgan_loss_D( self.netD, self.losses["sr_adv"], self.hr, self.sr ) loss_dict["sr_adv_d"] = sr_adv_d self.optimizers["netD"].zero_grad() sr_adv_d.backward() self.optimizers["netD"].step() self.log_dict = self.reduce_loss_dict(loss_dict) def calculate_rgan_loss_D(self, netD, criterion, real, fake): d_pred_fake = netD(fake.detach()) d_pred_real = netD(real) loss_real = criterion( d_pred_real - d_pred_fake.detach().mean(), True, is_disc=False ) loss_fake = criterion( d_pred_fake - d_pred_real.detach().mean(), False, is_disc=False ) loss = (loss_real + loss_fake) / 2 return loss def calculate_rgan_loss_G(self, netD, criterion, real, fake): d_pred_fake = netD(fake) d_pred_real = netD(real).detach() loss_real = criterion(d_pred_real - d_pred_fake.mean(), False, is_disc=False) loss_fake = criterion(d_pred_fake - d_pred_real.mean(), True, is_disc=False) loss = (loss_real + loss_fake) / 2 return loss def test(self, data, crop_size=None): self.real_lr = data["src"].to(self.device) self.netSR.eval() with torch.no_grad(): if crop_size is None: self.fake_real_hr = self.netSR(self.real_lr) else: self.fake_real_hr = self.crop_test(self.real_lr, crop_size) self.netSR.train() def crop_test(self, lr, crop_size): b, c, h, w = lr.shape scale = self.opt["scale"] h_start = list(range(0, h-crop_size, crop_size)) w_start = list(range(0, w-crop_size, 
crop_size)) sr1 = torch.zeros(b, c, int(h*scale), int(w* scale), device=self.device) - 1 for hs in h_start: for ws in w_start: lr_patch = lr[:, :, hs: hs+crop_size, ws: ws+crop_size] sr_patch = self.netSR(lr_patch) sr1[:, :, int(hs*scale):int((hs+crop_size)*scale), int(ws*scale):int((ws+crop_size)*scale) ] = sr_patch h_end = list(range(h, crop_size, -crop_size)) w_end = list(range(w, crop_size, -crop_size)) sr2 = torch.zeros(b, c, int(h*scale), int(w* scale), device=self.device) - 1 for hd in h_end: for wd in w_end: lr_patch = lr[:, :, hd-crop_size:hd, wd-crop_size:wd] sr_patch = self.netSR(lr_patch) sr2[:, :, int((hd-crop_size)*scale):int(hd*scale), int((wd-crop_size)*scale):int(wd*scale) ] = sr_patch mask1 = ( (sr1 == -1).float() * 0 + (sr2 == -1).float() * 1 + ((sr1 > 0) * (sr2 > 0)).float() * 0.5 ) mask2 = ( (sr1 == -1).float() * 1 + (sr2 == -1).float() * 0 + ((sr1 > 0) * (sr2 > 0)).float() * 0.5 ) sr = mask1 * sr1 + mask2 * sr2 return sr def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict["lr"] = self.real_lr.detach()[0].float().cpu() out_dict["sr"] = self.fake_real_hr.detach()[0].float().cpu() return out_dict ================================================ FILE: codes/config/Bicubic/options/test/2017Track2_2020Track1.yml ================================================ #### general settings name: Bicubic_2017Track2_2020Track1 use_tb_logger: false model: SRModel scale: 4 gpu_ids: [5] metrics: [psnr, ssim, lpips, niqe, piqe, brisque] datasets: test1: name: 2017Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb test5: name: 2020Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netSR: which_network: BicuBic setting: upscale: 4 pretrain: path: ~ strict_load: true 
def parse_args():
    """Parse command-line arguments for the test entry point."""
    parser = argparse.ArgumentParser(description="Train keypoints network")
    # general options
    parser.add_argument("--opt", help="experiment configure file name",
                        required=True, type=str)
    parser.add_argument("--root_path", help="experiment configure file name",
                        default="../../../", type=str)
    # distributed-training options
    parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str)
    parser.add_argument("--world-size", default=1, type=int,
                        help="number of nodes for distributed training")
    parser.add_argument("--dist-url", default="tcp://127.0.0.1:23456", type=str,
                        help="url used to set up distributed training")
    parser.add_argument("--rank", default=0, type=int,
                        help="node rank for distributed training")
    return parser.parse_args()
def validate(
    model, dataset, dist_loader, opt, measure, dataset_dir, test_set_name, logger
):
    """Run inference over a test set, save SR images, and log averaged metrics.

    Args:
        model: SR model wrapper providing ``test()`` / ``get_current_visuals()``.
        dataset: underlying dataset (used only for its length / global indexing).
        dist_loader: dataloader, possibly sharded across distributed ranks.
        opt (dict): global options (``metrics``, ``dist``, ``crop_border``,
            ``scale``, ``suffix``).
        measure: IQA callable computing the requested metrics.
        dataset_dir (str): directory where the result images are written.
        test_set_name (str): name used in the summary log lines.
        logger: logger for per-image and summary messages.
    """
    # Per-image score buffers, one slot per global dataset index
    # (BGR scores and Y-channel scores are tracked separately).
    test_results = {}
    test_results_y = {}
    for metric in opt["metrics"]:
        test_results[metric] = torch.zeros((len(dataset))).cuda()
        test_results_y[metric] = torch.zeros((len(dataset))).cuda()

    if opt["dist"]:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        world_size = 1
        rank = 0

    # this rank evaluates every world_size-th sample, starting at its rank
    indices = list(range(rank, len(dataset), world_size))

    for idx, test_data in enumerate(dist_loader):
        idx = indices[idx]  # map loader position back to the global index
        img_path = test_data["src_path"][0]
        img_name = img_path.split("/")[-1].split(".")[0]

        model.test(test_data)
        visuals = model.get_current_visuals()
        sr_img = util.tensor2img(visuals["sr"])  # uint8

        suffix = opt["suffix"]
        if suffix:
            save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png")
        else:
            save_img_path = os.path.join(dataset_dir, img_name + ".png")
        util.save_img(sr_img, save_img_path)

        message = "img:{:15s}; ".format(img_name)

        # crop borders before measuring; default crop equals the SR scale
        crop_border = opt["crop_border"] if opt["crop_border"] else opt["scale"]
        if crop_border == 0:
            cropped_sr_img = sr_img
        else:
            cropped_sr_img = sr_img[
                crop_border:-crop_border, crop_border:-crop_border, :
            ]

        if "tgt" in test_data.keys():
            gt_img = util.tensor2img(test_data["tgt"][0].double().cpu())
            if crop_border == 0:
                cropped_gt_img = gt_img
            else:
                cropped_gt_img = gt_img[
                    crop_border:-crop_border, crop_border:-crop_border, :
                ]
        else:
            gt_img = None
            cropped_gt_img = None

        message += "Scores - "
        scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"])
        for k, v in scores.items():
            test_results[k][idx] = v
            message += "{}: {:.6f}; ".format(k, v)

        if sr_img.shape[2] == 3:  # RGB image: also measure on the Y channel
            sr_img_y = bgr2ycbcr(sr_img, only_y=True)
            if crop_border == 0:
                cropped_sr_img_y = sr_img_y * 255
            else:
                cropped_sr_img_y = (
                    sr_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255
                )
            if gt_img is not None:
                gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                if crop_border == 0:
                    cropped_gt_img_y = gt_img_y * 255
                else:
                    cropped_gt_img_y = (
                        gt_img_y[crop_border:-crop_border, crop_border:-crop_border]
                        * 255
                    )
            else:
                gt_img_y = None
                cropped_gt_img_y = None

            message += "Y Scores - "
            scores = measure(
                res=cropped_sr_img_y, ref=cropped_gt_img_y, metrics=opt["metrics"]
            )
            for k, v in scores.items():
                test_results_y[k][idx] = v
                message += "{}: {:.6f}; ".format(k, v)

        logger.info(message)

    if opt["dist"]:
        # gather per-image scores onto rank 0 before averaging
        for k, v in test_results.items():
            dist.reduce(v, dst=0)
        dist.barrier()
        for k, v in test_results_y.items():
            dist.reduce(v, dst=0)
        dist.barrier()

    # log averages (rank 0 only)
    avg_results = {}
    message = "Average Results for {}\n".format(test_set_name)
    if rank == 0:
        for k, v in test_results.items():
            avg_results[k] = sum(v) / len(v)
            message += "{}: {:.6f}; ".format(k, avg_results[k])
        logger.info(message)

    avg_results_y = {}
    message = "Average Results on Y channel for {}\n".format(test_set_name)
    if rank == 0:
        for k, v in test_results_y.items():
            # fix: accumulate Y-channel averages in avg_results_y, not in
            # avg_results, so the BGR averages are not silently overwritten
            avg_results_y[k] = sum(v) / len(v)
            message += "{}: {:.6f}; ".format(k, avg_results_y[k])
        logger.info(message)


if __name__ == "__main__":
    main()
len(train_set), len(train_loader) ) ) logger.info( "Total epochs needed: {:d} for iters {:,d}".format( total_epochs, opt["train"]["niter"] ) ) elif phase == "val": val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt["dist"]) if rank == 0: logger.info( "Number of val images in [{:s}]: {:d}".format( dataset_opt["name"], len(val_set) ) ) else: raise NotImplementedError("Phase [{:s}] is not recognized.".format(phase)) assert train_loader is not None assert val_loader is not None return train_set, train_loader, val_set, val_loader, total_iters, total_epochs def main(): args = parse_args() opt = option.parse(args.opt, args.root_path, is_train=True) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) if args.dist_url == "env://" and args.world_size == -1: args.world_size = int(os.environ["WORLD_SIZE"]) ngpus_per_node = torch.cuda.device_count() args.world_size = ngpus_per_node * args.world_size opt["dist"] = args.world_size > 1 if opt["train"].get("resume_state", None) is None: util.mkdir_and_rename( opt["path"]["experiments_root"] ) # rename experiment folder if exists util.mkdirs( (path for key, path in opt["path"].items() if not key == "experiments_root") ) os.system("rm ./log") os.symlink(os.path.join(opt["path"]["experiments_root"], ".."), "./log") if opt["dist"]: mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args)) else: main_worker(0, 1, opt, args) def main_worker(gpu, ngpus_per_node, opt, args): if opt["dist"]: if args.dist_url == "env://" and args.rank == -1: rank = int(os.environ["RANK"]) rank = args.rank * ngpus_per_node + gpu print( f"Init process group: dist_url: \ {args.dist_url}, world_size: {args.world_size}, rank: {rank}" ) dist.init_process_group( backend="nccl", init_method=args.dist_url, world_size=args.world_size, rank=rank, ) torch.cuda.set_device(gpu) else: rank = 0 seed = opt["train"]["manual_seed"] if seed is None: 
util.set_random_seed(rank) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True # setup tensorboard and val logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger = SummaryWriter(log_dir="log/{}/tb_logger/".format(opt["name"])) util.setup_logger( "val", opt["path"]["log"], "val_" + opt["name"], level=logging.INFO, screen=True, tofile=True, ) measure = IQA(metrics=opt["metrics"], cuda=True) # config loggers. Before it, the log will not work util.setup_logger( "base", opt["path"]["log"], "train_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO if rank == 0 else logging.ERROR, screen=True, tofile=True, ) logger = logging.getLogger("base") if rank == 0: logger.info(option.dict2str(opt)) # create dataset ( train_set, train_loader, val_set, val_loader, total_iters, total_epochs, ) = setup_dataloaer(opt, logger) # create model model = create_model(opt) # loading resume state if exists if opt["train"].get("resume_state", None): # distributed resuming: all load into default GPU device_id = gpu resume_state = torch.load( opt["train"]["resume_state"], map_location=lambda storage, loc: storage.cuda(device_id), ) logger.info( "Resuming training from epoch: {}, iter: {}.".format( resume_state["epoch"], resume_state["iter"] ) ) start_epoch = resume_state["epoch"] current_step = resume_state["iter"] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 logger.info( "Start training from epoch: {:d}, iter: {:d}".format(start_epoch, current_step) ) data_time, iter_time = time.time(), time.time() avg_data_time = avg_iter_time = 0 count = 0 for epoch in range(start_epoch, total_epochs + 1): for _, train_data in enumerate(train_loader): current_step += 1 count += 1 if current_step > total_iters: break data_time = time.time() - data_time avg_data_time = (avg_data_time * (count - 1) + data_time) / count model.feed_data(train_data) 
model.optimize_parameters(current_step) model.update_learning_rate( current_step, warmup_iter=opt["train"]["warmup_iter"] ) iter_time = time.time() - iter_time avg_iter_time = (avg_iter_time * (count - 1) + iter_time) / count # log if current_step % opt["logger"]["print_freq"] == 0: logs = model.get_current_log() message = ( f" " ) message += f'[time (data): {avg_iter_time:.3f} ({avg_data_time:.3f})] ' for k, v in logs.items(): message += "{:s}: {:.4e}; ".format(k, v) # tensorboard logger if opt["use_tb_logger"] and "debug" not in opt["name"]: if rank == 0: tb_logger.add_scalar(k, v, current_step) logger.info(message) # validation if current_step % opt["train"]["val_freq"] == 0: avg_results = validate( model, val_set, val_loader, opt, measure, epoch, current_step ) # tensorboard logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: for k, v in avg_results.items(): tb_logger.add_scalar(k, v, current_step) # save models and training states if current_step % opt["logger"]["save_checkpoint_freq"] == 0: if rank == 0: logger.info("Saving models and training states.") model.save(current_step) model.save_training_state(epoch, current_step) data_time = time.time() iter_time = time.time() if rank == 0: logger.info("Saving the final model.") model.save("latest") logger.info("End of training.") if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger.close() def validate(model, dataset, dist_loader, opt, measure, epoch, current_step): test_results = {} for metric in opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 if rank == 0: pbar = tqdm(total=len(dataset), leave=False, dynamic_ncols=True) indices = list(range(rank, len(dataset), world_size)) for ( idx, val_data, ) in enumerate(dist_loader): idx = indices[idx] LR_img = val_data["src"] lr_img = util.tensor2img(LR_img) # save LR image for reference model.test(val_data) 
visuals = model.get_current_visuals() # Save images for reference img_name = val_data["src_path"][0].split("/")[-1].split(".")[0] img_dir = os.path.join(opt["path"]["val_images"], img_name) util.mkdir(img_dir) save_lr_path = os.path.join(img_dir, "{:s}_LR.png".format(img_name)) util.save_img(lr_img, save_lr_path) sr_img = util.tensor2img(visuals["sr"]) # uint8 save_img_path = os.path.join( img_dir, "{:s}_{:d}.png".format(img_name, current_step) ) util.save_img(sr_img, save_img_path) if "fake_lr" in visuals.keys(): fake_lr_img = util.tensor2img(visuals["fake_lr"]) save_img_path = os.path.join( img_dir, f"fake_lr_{current_step:d}.png" ) util.save_img(fake_lr_img, save_img_path) # calculate scores crop_size = opt["scale"] cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] if "tgt" in val_data.keys(): gt_img = util.tensor2img(val_data["tgt"]) cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] else: cropped_gt_img = gt_img = None scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v if rank == 0: for _ in range(world_size): pbar.update(1) if rank == 0: pbar.close() # log avg_results = {} message = " 0`, the perceptual loss will be calculated and the loss will multiplied by the weight. Default: 1.0. style_weight (float): If `style_weight > 0`, the style loss will be calculated and the loss will multiplied by the weight. Default: 0. criterion (str): Criterion used for perceptual loss. Default: 'l1'. 
""" def __init__( self, layer_weights, vgg_type="vgg19", use_input_norm=True, range_norm=False, perceptual_weight=1.0, style_weight=0.0, criterion="l1", ): super(PerceptualLoss, self).__init__() self.perceptual_weight = perceptual_weight self.style_weight = style_weight self.layer_weights = layer_weights self.vgg = VGGFeatureExtractor( layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, use_input_norm=use_input_norm, range_norm=range_norm, ) self.criterion_type = criterion if self.criterion_type == "l1": self.criterion = torch.nn.L1Loss() elif self.criterion_type == "l2": self.criterion = torch.nn.L2loss() elif self.criterion_type == "fro": self.criterion = None else: raise NotImplementedError(f"{criterion} criterion has not been supported.") def forward(self, x, gt): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). gt (Tensor): Ground-truth tensor with shape (n, c, h, w). Returns: Tensor: Forward results. """ # extract vgg features x_features = self.vgg(x) gt_features = self.vgg(gt.detach()) # calculate perceptual loss if self.perceptual_weight > 0: percep_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": percep_loss += ( torch.norm(x_features[k] - gt_features[k], p="fro") * self.layer_weights[k] ) else: percep_loss += ( self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] ) percep_loss *= self.perceptual_weight else: percep_loss = None # calculate style loss if self.style_weight > 0: style_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": style_loss += ( torch.norm( self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p="fro", ) * self.layer_weights[k] ) else: style_loss += ( self.criterion( self._gram_mat(x_features[k]), self._gram_mat(gt_features[k]), ) * self.layer_weights[k] ) style_loss *= self.style_weight else: style_loss = None return percep_loss, style_loss def _gram_mat(self, x): """Calculate Gram matrix. 
class GradientPenaltyLoss(nn.Module):
    """WGAN-GP gradient penalty: mean((||d critic / d interp||_2 - 1)^2)."""

    def __init__(self, device=torch.device("cpu")):
        super(GradientPenaltyLoss, self).__init__()
        # ones-filled buffer reused as grad_outputs, resized lazily per input
        self.register_buffer("grad_outputs", torch.Tensor())
        self.grad_outputs = self.grad_outputs.to(device)

    def get_grad_outputs(self, input):
        """Return a ones tensor shaped like ``input``, cached between calls."""
        if input.size() != self.grad_outputs.size():
            self.grad_outputs.resize_(input.size()).fill_(1.0)
        return self.grad_outputs

    def forward(self, interp, interp_crit):
        """Compute the gradient penalty.

        Args:
            interp (Tensor): interpolated samples (must require grad).
            interp_crit (Tensor): critic output evaluated at ``interp``.

        Returns:
            Tensor: scalar penalty term.
        """
        ones = self.get_grad_outputs(interp_crit)
        (gradients,) = torch.autograd.grad(
            outputs=interp_crit,
            inputs=interp,
            grad_outputs=ones,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )
        flat = gradients.view(gradients.size(0), -1)
        return ((flat.norm(2, dim=1) - 1) ** 2).mean()
@LR_SCHEDULER_REGISTRY.register()
class MultiStepRestartLR(_LRScheduler):
    """Multi-step LR schedule with optional warm restarts.

    At each epoch listed in ``restarts`` the learning rate is reset to
    ``initial_lr * weight`` (and the optimizer state optionally cleared);
    at each milestone epoch the rate is multiplied by ``gamma``.
    """

    def __init__(
        self,
        optimizer,
        milestones,
        restarts=None,
        weights=None,
        gamma=0.1,
        clear_state=False,
        last_epoch=-1,
    ):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.clear_state = clear_state
        self.restarts = restarts if restarts else [0]
        self.restart_weights = weights if weights else [1]
        assert len(self.restarts) == len(
            self.restart_weights
        ), "restarts and their weights do not match."
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # restart epoch: reset every group to its initial lr, scaled
        if self.last_epoch in self.restarts:
            if self.clear_state:
                self.optimizer.state = defaultdict(dict)
            scale = self.restart_weights[self.restarts.index(self.last_epoch)]
            return [g["initial_lr"] * scale for g in self.optimizer.param_groups]
        # non-milestone epoch: keep the current lr
        if self.last_epoch not in self.milestones:
            return [g["lr"] for g in self.optimizer.param_groups]
        # milestone epoch: decay by gamma, once per listed occurrence
        decay = self.gamma ** self.milestones[self.last_epoch]
        return [g["lr"] * decay for g in self.optimizer.param_groups]
super().__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch == 0: return self.base_lrs elif self.last_epoch in self.restarts: self.last_restart = self.last_epoch self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [ group["initial_lr"] * weight for group in self.optimizer.param_groups ] elif (self.last_epoch - self.last_restart - 1 - self.T_max) % ( 2 * self.T_max ) == 0: return [ group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) ] return [ (1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / ( 1 + math.cos( math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max ) ) * (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups ] ================================================ FILE: codes/config/Bulat/archs/module_util.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init def initialize_weights(net_l, scale=1): if not isinstance(net_l, list): net_l = [net_l] for net in net_l: for m in net.modules(): if isinstance(m, nn.Conv2d): init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale # for residual block if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.Linear): init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): init.constant_(m.weight, 1) init.constant_(m.bias.data, 0.0) def make_layer(block, n_layers): layers = [] for _ in range(n_layers): layers.append(block()) return nn.Sequential(*layers) class ResidualBlock_noBN(nn.Module): """Residual block w/o BN ---Conv-ReLU-Conv-+- |________________| """ def __init__(self, nf=64): super(ResidualBlock_noBN, self).__init__() 
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) # initialization initialize_weights([self.conv1, self.conv2], 0.1) def forward(self, x): identity = x out = F.relu(self.conv1(x), inplace=True) out = self.conv2(out) return identity + out def flow_warp(x, flow, interp_mode="bilinear", padding_mode="zeros"): """Warp an image or feature map with optical flow Args: x (Tensor): size (N, C, H, W) flow (Tensor): size (N, H, W, 2), normal value interp_mode (str): 'nearest' or 'bilinear' padding_mode (str): 'zeros' or 'border' or 'reflection' Returns: Tensor: warped image or feature map """ assert x.size()[-2:] == flow.size()[1:3] B, C, H, W = x.size() # mesh grid grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 grid.requires_grad = False grid = grid.type_as(x) vgrid = grid + flow # scale grid to [-1,1] vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) return output ================================================ FILE: codes/config/Bulat/archs/rcan.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from utils.registry import ARCH_REGISTRY def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias ) class MeanShift(nn.Conv2d): def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): super(MeanShift, self).__init__(3, 3, kernel_size=1) std = torch.Tensor(rgb_std) self.weight.data = torch.eye(3).view(3, 3, 1, 1) self.weight.data.div_(std.view(3, 1, 1, 1)) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 
class ResBlock(nn.Module):
    """Plain residual block: conv-act-conv with a scaled skip connection."""

    def __init__(
        self,
        conv,
        n_feat,
        kernel_size,
        bias=True,
        bn=False,
        act=nn.ReLU(True),
        res_scale=1,
    ):
        super(ResBlock, self).__init__()
        layers = []
        for i in range(2):
            layers.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                layers.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                # activation only between the two convolutions
                layers.append(act)
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        return x + self.body(x).mul(self.res_scale)
class CALayer(nn.Module):
    """Channel attention: squeeze (global avg pool) + excitation (1x1 convs)."""

    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # squeeze: collapse each feature map to a single descriptor
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # excitation: bottleneck producing one sigmoid gate per channel
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gate = self.conv_du(self.avg_pool(x))
        return x * gate
modules_body.append(conv(n_feat, n_feat, kernel_size)) self.body = nn.Sequential(*modules_body) def forward(self, x): res = self.body(x) res += x return res ## Residual Channel Attention Network (RCAN) @ARCH_REGISTRY.register() class RCAN(nn.Module): def __init__(self, ng, nb, nf, reduction=16, upscale=4, conv=default_conv): super(RCAN, self).__init__() n_resgroups = ng n_resblocks = nb n_feats = nf kernel_size = 3 reduction = reduction scale = upscale act = nn.ReLU(True) # RGB mean for DIV2K rgb_mean = (0.4488, 0.4371, 0.4040) rgb_std = (1.0, 1.0, 1.0) self.sub_mean = MeanShift(1.0, rgb_mean, rgb_std, -1) # define head module modules_head = [conv(3, n_feats, kernel_size)] # define body module modules_body = [ ResidualGroup( conv, n_feats, kernel_size, reduction, act=act, res_scale=1.0, n_resblocks=nb, ) for _ in range(ng) ] modules_body.append(conv(n_feats, n_feats, kernel_size)) # define tail module modules_tail = [ Upsampler(conv, scale, n_feats, act=False), conv(n_feats, 3, kernel_size), ] self.add_mean = MeanShift(1.0, rgb_mean, rgb_std, 1) self.head = nn.Sequential(*modules_head) self.body = nn.Sequential(*modules_body) self.tail = nn.Sequential(*modules_tail) def forward(self, x): x = self.sub_mean(x) x = self.head(x) res = self.body(x) res += x x = self.tail(res) x = self.add_mean(x) return x def load_state_dict(self, state_dict, strict=False): own_state = self.state_dict() for name, param in state_dict.items(): if name in own_state: if isinstance(param, nn.Parameter): param = param.data try: own_state[name].copy_(param) except Exception: if name.find("tail") >= 0: print("Replace pre-trained upsampler to new one...") else: raise RuntimeError( "While copying the parameter named {}, " "whose dimensions in the model are {} and " "whose dimensions in the checkpoint are {}.".format( name, own_state[name].size(), param.size() ) ) elif strict: if name.find("tail") == -1: raise KeyError('unexpected key "{}" in state_dict'.format(name)) if strict: missing = 
class ResidualDenseBlock_5C(nn.Module):
    """Dense block of five convs; each conv sees all previous feature maps.

    The output is the residual sum ``x + 0.2 * conv5(...)``.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # small init keeps the residual branch near zero at the start
        initialize_weights(
            [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1
        )

    def forward(self, x):
        feats = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            feats.append(self.lrelu(conv(torch.cat(feats, 1))))
        out = self.conv5(torch.cat(feats, 1))
        return out * 0.2 + x
@ARCH_REGISTRY.register()
class MSRResNet(nn.Module):
    """modified SRResNet"""

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
        super(MSRResNet, self).__init__()
        self.upscale = upscale

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.recon_trunk = make_layer(functools.partial(ResidualBlock_noBN, nf=nf), nb)

        # upsampling: one pixel-shuffle step for x2/x3, two x2 steps for x4
        if self.upscale == 2:
            self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(2)
        elif self.upscale == 3:
            self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(3)
        elif self.upscale == 4:
            self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(2)

        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # small init on the non-residual convs
        initialize_weights(
            [self.conv_first, self.upconv1, self.HRconv, self.conv_last], 0.1
        )
        if self.upscale == 4:
            initialize_weights(self.upconv2, 0.1)

    def forward(self, x):
        feat = self.lrelu(self.conv_first(x))
        out = self.recon_trunk(feat)

        if self.upscale == 4:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        elif self.upscale in (2, 3):
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))

        out = self.conv_last(self.lrelu(self.HRconv(out)))
        # bilinear global skip: the trunk only has to learn the residual
        out += F.interpolate(
            x, scale_factor=self.upscale, mode="bilinear", align_corners=False
        )
        return out
Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? for _ in range(int(math.log(scale, 2))): m.append(conv(n_feat, 4 * n_feat, 3, bias)) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 3: m.append(conv(n_feat, 9 * n_feat, 3, bias)) m.append(nn.PixelShuffle(3)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 1: m.append(nn.Identity()) else: raise NotImplementedError super(Upsampler, self).__init__(*m) @ARCH_REGISTRY.register() class Translator(nn.Module): def __init__(self, nb, nf, scale=4, zero_tail=False, conv=default_conv): super().__init__() self.scale = scale # define head module if scale >= 1: m_head = [conv(3, nf, 3)] else: s = int(1 / scale) m_head = [nn.Conv2d(3, nf, kernel_size=2 * s + 1, stride=s, padding=s)] # define body module m_body = [ ResBlock(conv, nf, 3, act=nn.ReLU(True), res_scale=1) for _ in range(nb) ] m_body.append(conv(nf, nf, 3)) # define tail module m_tail = [ Upsampler(conv, scale, nf, act=False) if scale > 1 else nn.Identity(), conv(nf, 3, 3), ] self.head = nn.Sequential(*m_head) self.body = nn.Sequential(*m_body) self.tail = nn.Sequential(*m_tail) if zero_tail: nn.init.constant_(self.tail[-1].weight, 0) nn.init.constant_(self.tail[-1].bias, 0) def forward(self, x): f = self.head(x) f = self.body(f) f = self.tail(f) if self.scale == 1: x = f + x else: x = f + F.interpolate(x, scale_factor=self.scale) return x ================================================ FILE: codes/config/Bulat/archs/vgg.py ================================================ import os from collections import OrderedDict import torch from torch import nn as nn from torchvision.models import vgg as vgg from utils.registry import ARCH_REGISTRY VGG_PRETRAIN_PATH = "checkpoints/pretrained_models/vgg19-dcbb9e9d.pth" NAMES = { "vgg11": [ "conv1_1", "relu1_1", "pool1", "conv2_1", 
"relu2_1", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "pool5", ], "vgg13": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "pool5", ], "vgg16": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "pool5", ], "vgg19": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "conv3_4", "relu3_4", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "conv4_4", "relu4_4", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "conv5_4", "relu5_4", "pool5", ], } def insert_bn(names): """Insert bn layer after each conv. Args: names (list): The list of layer names. Returns: list: The list of layer names with bn layers. """ names_bn = [] for name in names: names_bn.append(name) if "conv" in name: position = name.replace("conv", "") names_bn.append("bn" + position) return names_bn @ARCH_REGISTRY.register() class VGGFeatureExtractor(nn.Module): """VGG network for feature extraction. In this implementation, we allow users to choose whether use normalization in the input feature and the type of vgg network. Note that the pretrained path must fit the vgg type. Args: layer_name_list (list[str]): Forward function returns the corresponding features according to the layer_name_list. Example: {'relu1_1', 'relu2_1', 'relu3_1'}. 
vgg_type (str): Set the type of vgg network. Default: 'vgg19'. use_input_norm (bool): If True, normalize the input image. Importantly, the input feature must in the range [0, 1]. Default: True. range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. Default: False. requires_grad (bool): If true, the parameters of VGG network will be optimized. Default: False. remove_pooling (bool): If true, the max pooling operations in VGG net will be removed. Default: False. pooling_stride (int): The stride of max pooling operation. Default: 2. """ def __init__( self, layer_name_list, vgg_type="vgg19", use_input_norm=True, range_norm=False, requires_grad=False, remove_pooling=False, pooling_stride=2, ): super(VGGFeatureExtractor, self).__init__() self.layer_name_list = layer_name_list self.use_input_norm = use_input_norm self.range_norm = range_norm self.names = NAMES[vgg_type.replace("_bn", "")] if "bn" in vgg_type: self.names = insert_bn(self.names) # only borrow layers that will be used to avoid unused params max_idx = 0 for v in layer_name_list: idx = self.names.index(v) if idx > max_idx: max_idx = idx if os.path.exists(VGG_PRETRAIN_PATH): vgg_net = getattr(vgg, vgg_type)(pretrained=False) state_dict = torch.load( VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage ) vgg_net.load_state_dict(state_dict) else: vgg_net = getattr(vgg, vgg_type)(pretrained=True) features = vgg_net.features[: max_idx + 1] modified_net = OrderedDict() for k, v in zip(self.names, features): if "pool" in k: # if remove_pooling is true, pooling operation will be removed if remove_pooling: continue else: # in some cases, we may want to change the default stride modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) else: modified_net[k] = v self.vgg_net = nn.Sequential(modified_net) if not requires_grad: self.vgg_net.eval() for param in self.parameters(): param.requires_grad = False else: self.vgg_net.train() for param in self.parameters(): param.requires_grad = True 
if self.use_input_norm: # the mean is for image with range [0, 1] self.register_buffer( "mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) ) # the std is for image with range [0, 1] self.register_buffer( "std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) ) def forward(self, x): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). Returns: Tensor: Forward results. """ if self.range_norm: x = (x + 1) / 2 if self.use_input_norm: x = (x - self.mean) / self.std output = {} for key, layer in self.vgg_net._modules.items(): x = layer(x) if key in self.layer_name_list: output[key] = x.clone() return output ================================================ FILE: codes/config/Bulat/count_flops.py ================================================ import argparse import sys import torch from torchsummaryX import summary sys.path.append("../../") import utils.option as option from models import create_model parser = argparse.ArgumentParser() parser.add_argument( "--opt", type=str, default="options/setting1/test/test_setting1_x4.yml", help="Path to option YMAL file of Predictor.", ) args = parser.parse_args() opt = option.parse(args.opt, root_path=".", is_train=True) opt = option.dict_to_nonedict(opt) model = create_model(opt) test_tensor = torch.randn(1, 3, 270, 180).cuda() for name, net in model.networks.items(): summary(net.cuda(), x=test_tensor) print("Above are results for net {}".format(name)) input() ================================================ FILE: codes/config/Bulat/inference.py ================================================ import argparse import logging import math import os import os.path as osp import random import sys import cv2 from collections import defaultdict from glob import glob from tqdm import tqdm import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp from tensorboardX import SummaryWriter sys.path.append("../../") import utils as util import utils.option as option 
from data import create_dataloader, create_dataset from data.data_sampler import DistIterSampler from metrics import IQA from models import create_model #### options parser = argparse.ArgumentParser() parser.add_argument( "-opt", type=str, default="options/test/2020Track2.yml", help="Path to options YMAL file.", ) parser.add_argument("-input_dir", type=str, default="../../../data_samples/LR") parser.add_argument("-output_dir", type=str, default="../../../data_samples/BSRGAN") args = parser.parse_args() opt = option.parse(args.opt, is_train=False) opt = option.dict_to_nonedict(opt) model = create_model(opt) if not osp.exists(args.output_dir): os.makedirs(args.output_dir) test_files = glob(osp.join(args.input_dir, "*")) for inx, path in tqdm(enumerate(test_files)): name = path.split("/")[-1].split(".")[0] img = cv2.imread(path)[:, :, [2, 1, 0]] img = img.transpose(2, 0, 1)[None] / 255 img_t = torch.as_tensor(np.ascontiguousarray(img)).float() model.test({"src": img_t}, crop_size=512) outdict = model.get_current_visuals() sr = outdict["sr"] sr_im = util.tensor2img(sr) save_path = osp.join(args.output_dir, "{}_x{}.png".format(name, opt["scale"])) cv2.imwrite(save_path, sr_im) ================================================ FILE: codes/config/Bulat/models/__init__.py ================================================ import importlib import logging import os import os.path as osp from utils.registry import MODEL_REGISTRY logger = logging.getLogger("base") model_folder = osp.dirname(__file__) model_names = [ osp.splitext(osp.basename(v))[0] for v in os.listdir(model_folder) if v.endswith("_model.py") ] _model_modules = [ importlib.import_module(f"models.{file_name}") for file_name in model_names ] def create_model(opt, **kwarg): model = opt["model"] m = MODEL_REGISTRY.get(model)(opt, **kwarg) logger.info("Model [{:s}] is created.".format(m.__class__.__name__)) return m ================================================ FILE: codes/config/Bulat/models/base_model.py 
================================================ import logging import os from collections import OrderedDict import torch import torch.nn as nn from torch.nn.parallel import DataParallel, DistributedDataParallel from archs import build_loss, build_network, build_scheduler from utils.registry import MODEL_REGISTRY logger = logging.getLogger("base") @MODEL_REGISTRY.register() class BaseModel: def __init__(self, opt): self.opt = opt if opt["dist"]: self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() else: self.rank = 0 # non dist training self.device = torch.device("cuda" if opt["gpu_ids"] is not None else "cpu") self.is_train = opt["is_train"] self.log_dict = OrderedDict() self.data_names = [] self.networks = {} self.optimizers = {} self.schedulers = {} def setup_train(self, train_opt): # define losses loss_opt = train_opt["losses"] self.losses = self.build_losses(loss_opt) # build optmizers optimizer_opts = train_opt["optimizers"] self.optimizers = self.build_optimizers(optimizer_opts) # set schedulers scheduler_opts = train_opt["schedulers"] self.schedulers = self.build_schedulers(scheduler_opts) # set to training state self.set_network_state(self.networks.keys(), "train") def feed_data(self, data): pass def optimize_parameters(self): pass def get_current_visuals(self): pass def get_current_losses(self): pass def print_network(self): pass def save(self, label): pass def load(self): pass def build_network(self, net_opt): net = build_network(net_opt) if isinstance(net, nn.Module): net = self.model_to_device(net) if net_opt.get("pretrain"): pretrain = net_opt.pop("pretrain") self.load_network(net, pretrain["path"], pretrain["strict_load"]) self.print_network(net) return net def build_losses(self, loss_opt): losses = {} defined_loss_names = list(loss_opt.keys()) assert set(defined_loss_names).issubset(set(self.loss_names)) for name in defined_loss_names: loss_conf = loss_opt.get(name) if loss_conf["weight"] > 0: 
self.loss_weights[name] = loss_conf.pop("weight") losses[name] = build_loss(loss_conf).to(self.device) return losses def build_optimizers(self, optim_opts): optimizers = {} if "default" in optim_opts.keys(): default_optim = optim_opts.pop("default") defined_optimizer_names = list(optim_opts.keys()) assert set(defined_optimizer_names).issubset(self.networks.keys()) for name in defined_optimizer_names: optim_opt = optim_opts[name] if optim_opt is None: optim_opt = default_optim.copy() params = [] for v in self.networks[name].parameters(): if v.requires_grad: params.append(v) optim_type = optim_opt.pop("type") optimizer = getattr(torch.optim, optim_type)(params=params, **optim_opt) optimizers[name] = optimizer return optimizers def build_schedulers(self, scheduler_opts): """Set up scheduler.""" schedulers = {} if "default" in scheduler_opts.keys(): default_opt = scheduler_opts.pop("default") for name in self.optimizers.keys(): scheduler_opt = scheduler_opts[name] if scheduler_opt is None: scheduler_opt = default_opt.copy() schedulers[name] = build_scheduler(self.optimizers[name], scheduler_opt) return schedulers def model_to_device(self, net): """Model to device. It also warps models with DistributedDataParallel or DataParallel. 
Args: net (nn.Module) """ net = net.to(self.device) if self.opt["dist"]: net = DistributedDataParallel(net, device_ids=[torch.cuda.current_device()]) else: net = DataParallel(net) return net def print_network(self, net): # Generator s, n = self.get_network_description(net) if isinstance(net, nn.DataParallel) or isinstance(net, DistributedDataParallel): net_struc_str = "{} - {}".format( net.__class__.__name__, net.module.__class__.__name__ ) else: net_struc_str = "{}".format(net.__class__.__name__) if self.rank <= 0: logger.info( "Network G structure: {}, with parameters: {:,d}".format( net_struc_str, n ) ) logger.info(s) def set_optimizer(self, names, operation): for name in names: getattr(self.optimizers[name], operation)() def set_requires_grad(self, names, requires_grad): for name in names: if isinstance(self.networks[name], nn.Module): for v in self.networks[name].parameters(): v.requires_grad = requires_grad def set_network_state(self, names, state): for name in names: if isinstance(self.networks[name], nn.Module): getattr(self.networks[name], state)() def clip_grad_norm(self, names, norm): for name in names: nn.utils.clip_grad_norm_(self.networks[name].parameters(), max_norm=norm) def _set_lr(self, lr_groups_l): """set learning rate for warmup, lr_groups_l: list for lr_groups. 
each for a optimizer""" for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): for param_group, lr in zip(optimizer.param_groups, lr_groups): param_group["lr"] = lr def _get_init_lr(self): # get the initial lr, which is set by the scheduler init_lr_groups_l = [] for optimizer in self.optimizers: init_lr_groups_l.append([v["initial_lr"] for v in optimizer.param_groups]) return init_lr_groups_l def update_learning_rate(self, cur_iter, warmup_iter=-1): for _, scheduler in self.schedulers.items(): scheduler.step() #### set up warm up learning rate if cur_iter < warmup_iter: # get initial lr for each group init_lr_g_l = self._get_init_lr() # modify warming-up learning rates warm_up_lr_l = [] for init_lr_g in init_lr_g_l: warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) # set learning rate self._set_lr(warm_up_lr_l) def get_current_learning_rate(self): # return self.schedulers[0].get_lr()[0] return list(self.optimizers.values())[0].param_groups[0]["lr"] def get_network_description(self, network): """Get the string and total parameters of the network""" if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module s = str(network) n = sum(map(lambda x: x.numel(), network.parameters())) return s, n def save_network(self, network, network_label, iter_label): save_filename = "{}_{}.pth".format(iter_label, network_label) save_path = os.path.join(self.opt["path"]["models"], save_filename) if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module state_dict = network.state_dict() for key, param in state_dict.items(): state_dict[key] = param.cpu() torch.save(state_dict, save_path) def save(self, iter_label): for name in self.optimizers.keys(): self.save_network(self.networks[name], name, iter_label) def load_network(self, network, load_path, strict=True): if load_path is not None: if isinstance(network, nn.DataParallel) or isinstance( 
network, DistributedDataParallel ): network = network.module load_net = torch.load(load_path) load_net_clean = OrderedDict() # remove unnecessary 'module.' for k, v in load_net.items(): if k.startswith("module."): load_net_clean[k[7:]] = v else: load_net_clean[k] = v network.load_state_dict(load_net_clean, strict=strict) def save_training_state(self, epoch, iter_step): """Saves training state during training, which will be used for resuming""" state = {"epoch": epoch, "iter": iter_step, "schedulers": {}, "optimizers": {}} for k, s in self.schedulers.items(): state["schedulers"][k] = s.state_dict() for k, o in self.optimizers.items(): state["optimizers"][k] = o.state_dict() save_filename = "{}.state".format(iter_step) save_path = os.path.join(self.opt["path"]["training_state"], save_filename) torch.save(state, save_path) def resume_training(self, resume_state): """Resume the optimizers and schedulers for training""" resume_optimizers = resume_state["optimizers"] resume_schedulers = resume_state["schedulers"] assert len(resume_optimizers) == len( self.optimizers ), "Wrong lengths of optimizers" assert len(resume_schedulers) == len( self.schedulers ), "Wrong lengths of schedulers" for name, o in resume_optimizers.items(): self.optimizers[name].load_state_dict(o) for name, s in resume_schedulers.items(): self.schedulers[name].load_state_dict(s) def reduce_loss_dict(self, loss_dict): """reduce loss dict. In distributed training, it averages the losses among different GPUs . Args: loss_dict (OrderedDict): Loss dict. 
""" with torch.no_grad(): if self.opt["dist"]: keys = [] losses = [] for name, value in loss_dict.items(): keys.append(name) losses.append(value) losses = torch.stack(losses, 0) torch.distributed.reduce(losses, dst=0) if self.rank == 0: losses /= self.world_size loss_dict = {key: loss for key, loss in zip(keys, losses)} log_dict = OrderedDict() for name, value in loss_dict.items(): log_dict[name] = value.mean().item() return log_dict def get_current_log(self): return self.log_dict ================================================ FILE: codes/config/Bulat/models/deg_sr_model.py ================================================ import logging from collections import OrderedDict import random import torch import torch.nn as nn from utils.registry import MODEL_REGISTRY from models.base_model import BaseModel logger = logging.getLogger("base") @MODEL_REGISTRY.register() class DegSRModel(BaseModel): def __init__(self, opt): super().__init__(opt) if opt["dist"]: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training self.data_names = ["syn_lr", "syn_hr", "real_lr"] self.network_names = ["netSR", "netDeg", "netD1", "netD2"] self.networks = {} self.loss_names = [ "lr_adv", "lr_percep", "lr_color", "lr_tv", "sr_tv", "sr_pix", "sr_adv", "sr_percep" ] self.loss_weights = {} self.losses = {} self.optimizers = {} # define networks and load pretrained models nets_opt = opt["networks"] defined_network_names = list(nets_opt.keys()) assert set(defined_network_names).issubset(set(self.network_names)) for name in defined_network_names: setattr(self, name, self.build_network(nets_opt[name])) self.networks[name] = getattr(self, name) if self.is_train: # setup loss, optimizers, schedulers self.setup_train(opt["train"]) self.max_grad_norm = train_opt["max_grad_norm"] self.D_ratio = train_opt["D_ratio"] ## buffer self.fake_lr_buffer = ShuffleBuffer(train_opt["buffer_size"]) self.fake_hr_buffer = ShuffleBuffer(train_opt["buffer_size"]) def feed_data(self, data): 
self.syn_hr = data["tgt"].to(self.device) self.real_lr = data["src"].to(self.device) def forward(self): self.fake_real_lr = self.netDeg(self.syn_hr) self.fake_syn_hr = self.netSR(self.fake_real_lr) # self.fake_real_hr = self.netSR(self.real_lr) def optimize_parameters(self, step): self.forward() loss_dict = OrderedDict() loss_G = 0 if self.losses.get("lr_adv"): self.set_requires_grad(["netD1"], False) g1_adv_loss = self.calculate_gan_loss_G( self.netD1, self.losses["lr_adv"], self.real_lr, self.fake_real_lr ) loss_dict["g1_adv"] = g1_adv_loss.item() loss_G += self.loss_weights["lr_adv"] * g1_adv_loss if self.losses.get("lr_percep"): lr_percep, lr_style = self.losses["lr_percep"](self.real_lr, self.fake_real_lr) loss_dict["lr_percep"] = lr_percep.item() if lr_style is not None: loss_dict["lr_style"] = lr_style.item() loss_G += self.loss_weights["sr_percep"] * lr_style loss_G += self.loss_weights["sr_percep"] * lr_percep if self.losses.get("lr_color"): lr_color = self.losses["lr_color"](self.fake_real_lr, self.syn_hr) loss_dict["lr_color"] = lr_color.item() loss_G += self.loss_weights["lr_color"] * lr_color if self.losses.get("sr_adv"): self.set_requires_grad(["netD2"], False) sr_adv = self.calculate_gan_loss_G( self.netD2, self.losses["sr_adv"], self.syn_hr, self.fake_syn_hr ) loss_dict["sr_adv"] = sr_adv.item() loss_G += self.loss_weights["sr_adv"] * sr_adv if self.losses.get("sr_pix"): sr_pix = self.losses["sr_pix"](self.fake_syn_hr, self.syn_hr) loss_dict["sr_pix"] = sr_pix.item() loss_G += self.loss_weights["sr_pix"] * sr_pix if self.losses.get("sr_percep"): sr_percep, sr_style = self.losses["sr_percep"](self.syn_hr, self.fake_syn_hr) loss_dict["sr_percep"] = sr_percep.item() if sr_style is not None: loss_dict["sr_style"] = sr_style.item() loss_G += self.loss_weights["sr_percep"] * sr_style loss_G += self.loss_weights["sr_percep"] * sr_percep if self.losses.get("sr_tv"): sr_tv = self.losses["sr_tv"](self.fake_real_hr) loss_dict["sr_tv"] = sr_tv.item() loss_G = 
self.loss_weights["sr_tv"] * sr_tv self.set_optimizer(names=["netDeg", "netSR"], operation="zero_grad") loss_G.backward() self.clip_grad_norm(names=["netDeg", "netSR"], norm=self.max_grad_norm) self.set_optimizer(names=["netDeg", "netSR"], operation="step") ## update D1, D2 loss_D = 0 if self.losses.get("lr_adv"): if step % self.D_ratio == 0: self.set_requires_grad(["netD1"], True) loss_d1 = self.calculate_gan_loss_D( self.netD1, self.losses["lr_adv"], self.real_lr, self.fake_lr_buffer.choose(self.fake_real_lr.detach()) ) loss_dict["d1_adv"] = loss_d1.item() loss_d1 = self.loss_weights["lr_adv"] * loss_d1 self.set_optimizer(names=["netD1"], operation="zero_grad") loss_d1.backward() self.clip_grad_norm(["netD1"], norm=self.max_grad_norm) self.set_optimizer(names=["netD1"], operation="step") if self.losses.get("sr_adv"): if step % self.D_ratio == 0: self.set_requires_grad(["netD2"], True) loss_d2 = self.calculate_gan_loss_D( self.netD2, self.losses["sr_adv"], self.syn_hr, self.fake_sr_buffer.choose(self.fake_syn_hr.detach()) ) loss_dict["d2_adv"] = loss_d2.item() loss_d2 = self.loss_weights["sr_adv"] * loss_d2 self.set_optimizer(names=["netD2"], operation="zero_grad") loss_d1.backward() self.clip_grad_norm(["netD2"], norm=self.max_grad_norm) self.set_optimizer(names=["netD2"], operation="step") self.log_dict = loss_dict def calculate_gan_loss_D(self, netD, criterion, real, fake): d_pred_fake = netD(fake.detach()) d_pred_real = netD(real) loss_real = criterion(d_pred_real, True, is_disc=True) loss_fake = criterion(d_pred_fake, False, is_disc=True) return (loss_real + loss_fake) / 2 def calculate_gan_loss_G(self, netD, criterion, real, fake): d_pred_fake = netD(fake) loss_real = criterion(d_pred_fake, True, is_disc=False) return loss_real def test(self, data): self.real_lr = data["src"].to(self.device) self.set_network_state(["netSR"], "eval") with torch.no_grad(): self.fake_real_hr = self.netSR(self.real_lr) self.set_network_state(["netSR"], "train") def 
get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict["lr"] = self.real_lr.detach()[0].float().cpu() out_dict["sr"] = self.fake_real_hr.detach()[0].float().cpu() return out_dict class ShuffleBuffer(): """Random choose previous generated images or ones produced by the latest generators. :param buffer_size: the size of image buffer :type buffer_size: int """ def __init__(self, buffer_size): """Initialize the ImagePool class. :param buffer_size: the size of image buffer :type buffer_size: int """ self.buffer_size = buffer_size self.num_imgs = 0 self.images = [] def choose(self, images, prob=0.5): """Return an image from the pool. :param images: the latest generated images from the generator :type images: list :param prob: probability (0~1) of return previous images from buffer :type prob: float :return: Return images from the buffer :rtype: list """ if self.buffer_size == 0: return images return_images = [] for image in images: image = torch.unsqueeze(image.data, 0) if self.num_imgs < self.buffer_size: self.images.append(image) return_images.append(image) self.num_imgs += 1 else: p = random.uniform(0, 1) if p < prob: idx = random.randint(0, self.buffer_size - 1) stored_image = self.images[idx].clone() self.images[idx] = image return_images.append(stored_image) else: return_images.append(image) return_images = torch.cat(return_images, 0) return return_images ================================================ FILE: codes/config/Bulat/options/test/2017Track2.yml ================================================ #### general settings name: 2017Track2_psnr use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [0] metrics: [psnr, ssim, lpips, niqe, piqe, brisque] datasets: test1: name: 2017Track2 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nf: 64 nb: 8 
zero_tail: true pretrain: path: log/2017Track2/models/latest_netG1.pth strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/2017Track2/models/latest_netSR.pth strict_load: true ================================================ FILE: codes/config/Bulat/options/test/2018Track2.yml ================================================ #### general settings name: 2018Track2_psnr use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [1] metrics: [best_psnr, best_ssim, lpips, niqe, piqe, brisque] datasets: test1: name: 2018Track2 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nf: 64 nb: 8 zero_tail: true pretrain: path: log/2018Track2/models/latest_netG1.pth strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/2018Track2/models/latest_netSR.pth strict_load: true ================================================ FILE: codes/config/Bulat/options/test/2018Track4.yml ================================================ #### general settings name: 2018Track4_psnr use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [5] metrics: [best_psnr, best_ssim, lpips, niqe, piqe, brisque] datasets: test1: name: 2018Track4 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nf: 64 nb: 8 zero_tail: true pretrain: path: log/2018Track4/models/latest_netG1.pth strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/2018Track4/models/latest_netSR.pth strict_load: true ================================================ FILE: 
codes/config/Bulat/options/test/2020Track1.yml ================================================ #### general settings name: 2020Track1_psnr use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [5] metrics: [psnr, ssim, lpips, niqe, piqe, brisque] datasets: test1: name: 2020Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netG1: which_network: DegModel setting: scale: 4 nf: 64 nb: 8 zero_tail: true pretrain: path: log/2020Track1/models/latest_netG1.pth strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/2020Track1/models/latest_netSR.pth strict_load: true ================================================ FILE: codes/config/Bulat/options/train/psnr/2017Track2.yml ================================================ #### general settings name: 2017Track2 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [0] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/train_LR/x4_half.lmdb use_shuffle: true workers_per_gpu: 4 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2017Track1_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nf: 64 nb: 8 zero_tail: true pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 
nb: 3 stride: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 16 losses: lr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 lr_color: type: ColorLoss gauss_opt: ~ pool_opt: ksize: 4 loss_type: mse stride: 4 weight: 1.0 sr_pix: type: L1Loss weight: 1.0 optimizers: default: type: Adam lr: !!float 2e-4 netG1: ~ netSR: ~ netD1: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/Bulat/options/train/psnr/2018Track2.yml ================================================ #### general settings name: 2018Track2 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [3] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/x4_half.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2018Track2 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nf: 64 nb: 8 zero_tail: true pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ 
strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 16 losses: lr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 lr_color: type: ColorLoss gauss_opt: ~ pool_opt: ksize: 4 loss_type: mse stride: 4 weight: 1.0 sr_pix: type: L1Loss weight: 1.0 optimizers: default: type: Adam lr: !!float 2e-4 netG1: ~ netSR: ~ netD1: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/Bulat/options/train/psnr/2018Track4.yml ================================================ #### general settings name: 2018Track4 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [2] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/x4.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2018Track4 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nf: 64 nb: 8 zero_tail: true pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate 
scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 16 losses: lr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 lr_color: type: ColorLoss gauss_opt: ~ pool_opt: ksize: 4 loss_type: mse stride: 4 weight: 1.0 sr_pix: type: L1Loss weight: 1.0 optimizers: default: type: Adam lr: !!float 2e-4 netG1: ~ netSR: ~ netD1: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/Bulat/options/train/psnr/2020Track1.yml ================================================ #### general settings name: 2020Track1 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [1] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/train_source.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2020Track1 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nf: 64 nb: 8 zero_tail: true pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 
max_grad_norm: 50 buffer_size: 16 losses: lr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 lr_color: type: ColorLoss gauss_opt: ~ pool_opt: ksize: 4 loss_type: mse stride: 4 weight: 1.0 sr_pix: type: L1Loss weight: 1.0 optimizers: default: type: Adam lr: !!float 2e-4 netG1: ~ netSR: ~ netD1: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/Bulat/test.py ================================================ import argparse import logging import os.path import sys import time from collections import OrderedDict, defaultdict import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from metrics import IQA from models import create_model from utils import bgr2ycbcr, imresize def parse_args(): parser = argparse.ArgumentParser(description="Train keypoints network") # general parser.add_argument( "--opt", help="experiment configure file name", required=True, type=str ) parser.add_argument( "--root_path", help="experiment configure file name", default="../../../", type=str, ) # distributed training parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str) parser.add_argument( "--world-size", default=1, type=int, help="number of nodes for distributed training", ) parser.add_argument( "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training", ) parser.add_argument( "--rank", default=0, type=int, help="node rank for distributed training" ) args = parser.parse_args() return args def main(): args = parse_args() opt = option.parse(args.opt, 
args.root_path, is_train=False) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) if args.dist_url == "env://" and args.world_size == -1: args.world_size = int(os.environ["WORLD_SIZE"]) ngpus_per_node = torch.cuda.device_count() args.world_size = ngpus_per_node * args.world_size opt["dist"] = args.world_size > 1 util.mkdirs( (path for key, path in opt["path"].items() if not key == "experiments_root") ) os.system("rm ./result") os.symlink(os.path.join(opt["path"]["results_root"], ".."), "./result") if opt["dist"]: mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args)) else: main_worker(0, 1, opt, args) def main_worker(gpu, ngpus_per_node, opt, args): if opt["dist"]: if args.dist_url == "env://" and args.rank == -1: rank = int(os.environ["RANK"]) rank = args.rank * ngpus_per_node + gpu print( f"Init process group: dist_url: {args.dist_url}, world_size: {args.world_size}, rank: {rank}" ) dist.init_process_group( backend="nccl", init_method=args.dist_url, world_size=args.world_size, rank=rank, ) torch.cuda.set_device(gpu) else: rank = 0 torch.backends.cudnn.benchmark = True util.setup_logger( "base", opt["path"]["log"], "test_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO, screen=True, tofile=True, ) measure = IQA(metrics=opt["metrics"], cuda=True) logger = logging.getLogger("base") logger.info(option.dict2str(opt)) # Create test dataset and dataloader test_datasets = [] test_loaders = [] for phase, dataset_opt in sorted(opt["datasets"].items()): test_set = create_dataset(dataset_opt) test_loader = create_dataloader(test_set, dataset_opt, opt["dist"]) if rank == 0: logger.info( "Number of test images in [{:s}]: {:d}".format( dataset_opt["name"], len(test_set) ) ) test_datasets.append(test_set) test_loaders.append(test_loader) # load pretrained model by default model = create_model(opt) for test_dataset, test_loader in zip(test_datasets, test_loaders): test_set_name = 
test_dataset.opt["name"] dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name) if rank == 0: logger.info("\nTesting [{:s}]...".format(test_set_name)) util.mkdir(dataset_dir) validate( model, test_dataset, test_loader, opt, measure, dataset_dir, test_set_name, logger, ) def validate( model, dataset, dist_loader, opt, measure, dataset_dir, test_set_name, logger ): test_results = {} test_results_y = {} for metric in opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() test_results_y[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 indices = list(range(rank, len(dataset), world_size)) for ( idx, test_data, ) in enumerate(dist_loader): idx = indices[idx] img_path = test_data["src_path"][0] img_name = img_path.split("/")[-1].split(".")[0] model.test(test_data) visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals["sr"]) # uint8 suffix = opt["suffix"] if suffix: save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png") else: save_img_path = os.path.join(dataset_dir, img_name + ".png") util.save_img(sr_img, save_img_path) message = "img:{:15s}; ".format(img_name) crop_border = opt["crop_border"] if opt["crop_border"] else opt["scale"] if crop_border == 0: cropped_sr_img = sr_img else: cropped_sr_img = sr_img[ crop_border:-crop_border, crop_border:-crop_border, : ] if "tgt" in test_data.keys(): gt_img = util.tensor2img(test_data["tgt"][0].double().cpu()) if crop_border == 0: cropped_gt_img = gt_img else: cropped_gt_img = gt_img[ crop_border:-crop_border, crop_border:-crop_border, : ] else: cropped_gt_img = None message += "Scores - " scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v message += "{}: {:.6f}; ".format(k, v) if sr_img.shape[2] == 3: # RGB image sr_img_y = bgr2ycbcr(sr_img, only_y=True) if crop_border == 0: 
cropped_sr_img_y = sr_img_y * 255 else: cropped_sr_img_y = ( sr_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255 ) if gt_img is not None: gt_img_y = bgr2ycbcr(gt_img, only_y=True) if crop_border == 0: cropped_gt_img_y = gt_img_y * 255 else: cropped_gt_img_y = ( gt_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255 ) else: gt_img_y = None message += "Y Scores - " scores = measure( res=cropped_sr_img_y, ref=cropped_gt_img_y, metrics=opt["metrics"] ) for k, v in scores.items(): test_results_y[k][idx] = v message += "{}: {:.6f}; ".format(k, v) logger.info(message) if opt["dist"]: for k, v in test_results.items(): dist.reduce(v, dst=0) dist.barrier() for k, v in test_results_y.items(): dist.reduce(v, dst=0) dist.barrier() # log avg_results = {} message = "Average Results for {}\n".format(test_set_name) if rank == 0: for k, v in test_results.items(): avg_results[k] = sum(v) / len(v) message += "{}: {:.6f}; ".format(k, avg_results[k]) logger.info(message) avg_results_y = {} message = "Average Results on Y channel for {}\n".format(test_set_name) if rank == 0: for k, v in test_results_y.items(): avg_results[k] = sum(v) / len(v) message += "{}: {:.6f}; ".format(k, avg_results[k]) logger.info(message) if __name__ == "__main__": main() ================================================ FILE: codes/config/Bulat/train.py ================================================ import argparse import logging import math import os import random import sys import time from collections import defaultdict import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp from tensorboardX import SummaryWriter from tqdm import tqdm sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from metrics import IQA from models import create_model def parse_args(): parser = argparse.ArgumentParser(description="Train keypoints network") # general parser.add_argument( 
"--opt", help="experiment configure file name", required=True, type=str ) parser.add_argument( "--root_path", help="experiment configure file name", default="../../../", type=str, ) # distributed training parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str) parser.add_argument( "--world-size", default=1, type=int, help="number of nodes for distributed training", ) parser.add_argument( "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training", ) parser.add_argument( "--rank", default=0, type=int, help="node rank for distributed training" ) args = parser.parse_args() return args def setup_dataloaer(opt, logger): if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: rank = 0 world_size = 1 for phase, dataset_opt in opt["datasets"].items(): if phase == "train": train_set = create_dataset(dataset_opt) train_loader = create_dataloader(train_set, dataset_opt, opt["dist"]) total_iters = opt["train"]["niter"] total_epochs = total_iters // (len(train_loader) - 1) + 1 if rank == 0: logger.info( "Number of train images: {:,d}, iters: {:,d}".format( len(train_set), len(train_loader) ) ) logger.info( "Total epochs needed: {:d} for iters {:,d}".format( total_epochs, opt["train"]["niter"] ) ) elif phase == "val": val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt["dist"]) if rank == 0: logger.info( "Number of val images in [{:s}]: {:d}".format( dataset_opt["name"], len(val_set) ) ) else: raise NotImplementedError("Phase [{:s}] is not recognized.".format(phase)) assert train_loader is not None assert val_loader is not None return train_set, train_loader, val_set, val_loader, total_iters, total_epochs def main(): args = parse_args() opt = option.parse(args.opt, args.root_path, is_train=True) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) if args.dist_url == "env://" and args.world_size == -1: 
args.world_size = int(os.environ["WORLD_SIZE"]) ngpus_per_node = torch.cuda.device_count() args.world_size = ngpus_per_node * args.world_size opt["dist"] = args.world_size > 1 if opt["train"].get("resume_state", None) is None: util.mkdir_and_rename( opt["path"]["experiments_root"] ) # rename experiment folder if exists util.mkdirs( (path for key, path in opt["path"].items() if not key == "experiments_root") ) os.system("rm ./log") os.symlink(os.path.join(opt["path"]["experiments_root"], ".."), "./log") if opt["dist"]: mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args)) else: main_worker(0, 1, opt, args) def main_worker(gpu, ngpus_per_node, opt, args): if opt["dist"]: if args.dist_url == "env://" and args.rank == -1: rank = int(os.environ["RANK"]) rank = args.rank * ngpus_per_node + gpu print( f"Init process group: dist_url: \ {args.dist_url}, world_size: {args.world_size}, rank: {rank}" ) dist.init_process_group( backend="nccl", init_method=args.dist_url, world_size=args.world_size, rank=rank, ) torch.cuda.set_device(gpu) else: rank = 0 seed = opt["train"]["manual_seed"] if seed is None: util.set_random_seed(rank) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True # setup tensorboard and val logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger = SummaryWriter(log_dir="log/{}/tb_logger/".format(opt["name"])) util.setup_logger( "val", opt["path"]["log"], "val_" + opt["name"], level=logging.INFO, screen=True, tofile=True, ) measure = IQA(metrics=opt["metrics"], cuda=True) # config loggers. 
Before it, the log will not work util.setup_logger( "base", opt["path"]["log"], "train_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO if rank == 0 else logging.ERROR, screen=True, tofile=True, ) logger = logging.getLogger("base") if rank == 0: logger.info(option.dict2str(opt)) # create dataset ( train_set, train_loader, val_set, val_loader, total_iters, total_epochs, ) = setup_dataloaer(opt, logger) # create model model = create_model(opt) # loading resume state if exists if opt["train"].get("resume_state", None): # distributed resuming: all load into default GPU device_id = gpu resume_state = torch.load( opt["train"]["resume_state"], map_location=lambda storage, loc: storage.cuda(device_id), ) logger.info( "Resuming training from epoch: {}, iter: {}.".format( resume_state["epoch"], resume_state["iter"] ) ) start_epoch = resume_state["epoch"] current_step = resume_state["iter"] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 logger.info( "Start training from epoch: {:d}, iter: {:d}".format(start_epoch, current_step) ) data_time, iter_time = time.time(), time.time() avg_data_time = avg_iter_time = 0 count = 0 for epoch in range(start_epoch, total_epochs + 1): for _, train_data in enumerate(train_loader): current_step += 1 count += 1 if current_step > total_iters: break data_time = time.time() - data_time avg_data_time = (avg_data_time * (count - 1) + data_time) / count model.feed_data(train_data) model.optimize_parameters(current_step) model.update_learning_rate( current_step, warmup_iter=opt["train"]["warmup_iter"] ) iter_time = time.time() - iter_time avg_iter_time = (avg_iter_time * (count - 1) + iter_time) / count # log if current_step % opt["logger"]["print_freq"] == 0: logs = model.get_current_log() message = ( f" " ) message += f'[time (data): {avg_iter_time:.3f} ({avg_data_time:.3f})] ' for k, v in logs.items(): message += "{:s}: {:.4e}; ".format(k, v) # tensorboard logger if 
opt["use_tb_logger"] and "debug" not in opt["name"]: if rank == 0: tb_logger.add_scalar(k, v, current_step) logger.info(message) # validation if current_step % opt["train"]["val_freq"] == 0: avg_results = validate( model, val_set, val_loader, opt, measure, epoch, current_step ) # tensorboard logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: for k, v in avg_results.items(): tb_logger.add_scalar(k, v, current_step) # save models and training states if current_step % opt["logger"]["save_checkpoint_freq"] == 0: if rank == 0: logger.info("Saving models and training states.") model.save(current_step) model.save_training_state(epoch, current_step) data_time = time.time() iter_time = time.time() if rank == 0: logger.info("Saving the final model.") model.save("latest") logger.info("End of training.") if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger.close() def validate(model, dataset, dist_loader, opt, measure, epoch, current_step): test_results = {} for metric in opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 if rank == 0: pbar = tqdm(total=len(dataset), leave=False, dynamic_ncols=True) indices = list(range(rank, len(dataset), world_size)) for ( idx, val_data, ) in enumerate(dist_loader): idx = indices[idx] LR_img = val_data["src"] lr_img = util.tensor2img(LR_img) # save LR image for reference model.test(val_data) visuals = model.get_current_visuals() # Save images for reference img_name = val_data["src_path"][0].split("/")[-1].split(".")[0] img_dir = os.path.join(opt["path"]["val_images"], img_name) util.mkdir(img_dir) save_lr_path = os.path.join(img_dir, "{:s}_LR.png".format(img_name)) util.save_img(lr_img, save_lr_path) sr_img = util.tensor2img(visuals["sr"]) # uint8 save_img_path = os.path.join( img_dir, "{:s}_{:d}.png".format(img_name, current_step) ) util.save_img(sr_img, save_img_path) if 
"fake_lr" in visuals.keys(): fake_lr_img = util.tensor2img(visuals["fake_lr"]) save_img_path = os.path.join( img_dir, f"fake_lr_{current_step:d}.png" ) util.save_img(fake_lr_img, save_img_path) # calculate scores crop_size = opt["scale"] cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] if "tgt" in val_data.keys(): gt_img = util.tensor2img(val_data["tgt"]) cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] else: cropped_gt_img = gt_img = None scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v if rank == 0: for _ in range(world_size): pbar.update(1) if rank == 0: pbar.close() # log avg_results = {} message = " 0`, the perceptual loss will be calculated and the loss will multiplied by the weight. Default: 1.0. style_weight (float): If `style_weight > 0`, the style loss will be calculated and the loss will multiplied by the weight. Default: 0. criterion (str): Criterion used for perceptual loss. Default: 'l1'. """ def __init__( self, layer_weights, vgg_type="vgg19", use_input_norm=True, range_norm=False, perceptual_weight=1.0, style_weight=0.0, criterion="l1", ): super(PerceptualLoss, self).__init__() self.perceptual_weight = perceptual_weight self.style_weight = style_weight self.layer_weights = layer_weights self.vgg = VGGFeatureExtractor( layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, use_input_norm=use_input_norm, range_norm=range_norm, ) self.criterion_type = criterion if self.criterion_type == "l1": self.criterion = torch.nn.L1Loss() elif self.criterion_type == "l2": self.criterion = torch.nn.L2loss() elif self.criterion_type == "fro": self.criterion = None else: raise NotImplementedError(f"{criterion} criterion has not been supported.") def forward(self, x, gt): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). gt (Tensor): Ground-truth tensor with shape (n, c, h, w). 
Returns: Tensor: Forward results. """ # extract vgg features x_features = self.vgg(x) gt_features = self.vgg(gt.detach()) # calculate perceptual loss if self.perceptual_weight > 0: percep_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": percep_loss += ( torch.norm(x_features[k] - gt_features[k], p="fro") * self.layer_weights[k] ) else: percep_loss += ( self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] ) percep_loss *= self.perceptual_weight else: percep_loss = None # calculate style loss if self.style_weight > 0: style_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": style_loss += ( torch.norm( self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p="fro", ) * self.layer_weights[k] ) else: style_loss += ( self.criterion( self._gram_mat(x_features[k]), self._gram_mat(gt_features[k]), ) * self.layer_weights[k] ) style_loss *= self.style_weight else: style_loss = None return percep_loss, style_loss def _gram_mat(self, x): """Calculate Gram matrix. Args: x (torch.Tensor): Tensor with shape of (n, c, h, w). Returns: torch.Tensor: Gram matrix. 
""" n, c, h, w = x.size() features = x.view(n, c, w * h) features_t = features.transpose(1, 2) gram = features.bmm(features_t) / (c * h * w) return gram @LOSS_REGISTRY.register() class CharbonnierLoss(nn.Module): """Charbonnier Loss (L1)""" def __init__(self, eps=1e-6): super(CharbonnierLoss, self).__init__() self.eps = eps def forward(self, x, y): diff = x - y loss = torch.mean(torch.sqrt(diff * diff + self.eps)) return loss class GradientPenaltyLoss(nn.Module): def __init__(self, device=torch.device("cpu")): super(GradientPenaltyLoss, self).__init__() self.register_buffer("grad_outputs", torch.Tensor()) self.grad_outputs = self.grad_outputs.to(device) def get_grad_outputs(self, input): if self.grad_outputs.size() != input.size(): self.grad_outputs.resize_(input.size()).fill_(1.0) return self.grad_outputs def forward(self, interp, interp_crit): grad_outputs = self.get_grad_outputs(interp_crit) grad_interp = torch.autograd.grad( outputs=interp_crit, inputs=interp, grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True, )[0] grad_interp = grad_interp.view(grad_interp.size(0), -1) grad_interp_norm = grad_interp.norm(2, dim=1) loss = ((grad_interp_norm - 1) ** 2).mean() return loss ================================================ FILE: codes/config/CinGAN/archs/lr_scheduler.py ================================================ import math from collections import Counter, defaultdict import torch from torch.optim.lr_scheduler import _LRScheduler from utils.registry import LR_SCHEDULER_REGISTRY @LR_SCHEDULER_REGISTRY.register() class LinearDecayLR(_LRScheduler): def __init__( self, optimizer, decay_prop, total_steps, last_epoch=-1, ): self.decay_prop = decay_prop self.total_steps = total_steps super().__init__(optimizer, last_epoch) def get_lr(self): return [ group["initial_lr"] * (1 - (self.last_epoch + 1) * self.decay_prop / self.total_steps) for group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class 
MultiStepRestartLR(_LRScheduler): def __init__( self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, clear_state=False, last_epoch=-1, ): self.milestones = Counter(milestones) self.gamma = gamma self.clear_state = clear_state self.restarts = restarts if restarts else [0] self.restart_weights = weights if weights else [1] assert len(self.restarts) == len( self.restart_weights ), "restarts and their weights do not match." super().__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch in self.restarts: if self.clear_state: self.optimizer.state = defaultdict(dict) weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [ group["initial_lr"] * weight for group in self.optimizer.param_groups ] if self.last_epoch not in self.milestones: return [group["lr"] for group in self.optimizer.param_groups] return [ group["lr"] * self.gamma ** self.milestones[self.last_epoch] for group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class CosineAnnealingRestartLR(_LRScheduler): def __init__( self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1 ): self.T_period = T_period self.T_max = self.T_period[0] # current T period self.eta_min = eta_min self.restarts = restarts if restarts else [0] self.restart_weights = weights if weights else [1] self.last_restart = 0 assert len(self.restarts) == len( self.restart_weights ), "restarts and their weights do not match." 
super().__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch == 0: return self.base_lrs elif self.last_epoch in self.restarts: self.last_restart = self.last_epoch self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [ group["initial_lr"] * weight for group in self.optimizer.param_groups ] elif (self.last_epoch - self.last_restart - 1 - self.T_max) % ( 2 * self.T_max ) == 0: return [ group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) ] return [ (1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / ( 1 + math.cos( math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max ) ) * (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups ] ================================================ FILE: codes/config/CinGAN/archs/module_util.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init def initialize_weights(net_l, scale=1): if not isinstance(net_l, list): net_l = [net_l] for net in net_l: for m in net.modules(): if isinstance(m, nn.Conv2d): init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale # for residual block if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.Linear): init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): init.constant_(m.weight, 1) init.constant_(m.bias.data, 0.0) def make_layer(block, n_layers): layers = [] for _ in range(n_layers): layers.append(block()) return nn.Sequential(*layers) class ResidualBlock_noBN(nn.Module): """Residual block w/o BN ---Conv-ReLU-Conv-+- |________________| """ def __init__(self, nf=64): super(ResidualBlock_noBN, self).__init__() 
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) # initialization initialize_weights([self.conv1, self.conv2], 0.1) def forward(self, x): identity = x out = F.relu(self.conv1(x), inplace=True) out = self.conv2(out) return identity + out def flow_warp(x, flow, interp_mode="bilinear", padding_mode="zeros"): """Warp an image or feature map with optical flow Args: x (Tensor): size (N, C, H, W) flow (Tensor): size (N, H, W, 2), normal value interp_mode (str): 'nearest' or 'bilinear' padding_mode (str): 'zeros' or 'border' or 'reflection' Returns: Tensor: warped image or feature map """ assert x.size()[-2:] == flow.size()[1:3] B, C, H, W = x.size() # mesh grid grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 grid.requires_grad = False grid = grid.type_as(x) vgrid = grid + flow # scale grid to [-1,1] vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) return output ================================================ FILE: codes/config/CinGAN/archs/rcan.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from utils.registry import ARCH_REGISTRY def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias ) class MeanShift(nn.Conv2d): def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): super(MeanShift, self).__init__(3, 3, kernel_size=1) std = torch.Tensor(rgb_std) self.weight.data = torch.eye(3).view(3, 3, 1, 1) self.weight.data.div_(std.view(3, 1, 1, 1)) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 
self.bias.data.div_(std) self.requires_grad = False class BasicBlock(nn.Sequential): def __init__( self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True), ): m = [ nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), stride=stride, bias=bias, ) ] if bn: m.append(nn.BatchNorm2d(out_channels)) if act is not None: m.append(act) super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): def __init__( self, conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(ResBlock, self).__init__() m = [] for i in range(2): m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: m.append(nn.BatchNorm2d(n_feat)) if i == 0: m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale def forward(self, x): res = self.body(x).mul(self.res_scale) res += x return res class Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
for _ in range(int(math.log(scale, 2))): m.append(conv(n_feat, 4 * n_feat, 3, bias)) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 3: m.append(conv(n_feat, 9 * n_feat, 3, bias)) m.append(nn.PixelShuffle(3)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) else: raise NotImplementedError super(Upsampler, self).__init__(*m) def make_model(args, parent=False): return RCAN(args) ## Channel Attention (CA) Layer class CALayer(nn.Module): def __init__(self, channel, reduction=16): super(CALayer, self).__init__() # global average pooling: feature --> point self.avg_pool = nn.AdaptiveAvgPool2d(1) # feature channel downscale and upscale --> channel weight self.conv_du = nn.Sequential( nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), nn.ReLU(inplace=True), nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), nn.Sigmoid(), ) def forward(self, x): y = self.avg_pool(x) y = self.conv_du(y) return x * y ## Residual Channel Attention Block (RCAB) class RCAB(nn.Module): def __init__( self, conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(RCAB, self).__init__() modules_body = [] for i in range(2): modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: modules_body.append(nn.BatchNorm2d(n_feat)) if i == 0: modules_body.append(act) modules_body.append(CALayer(n_feat, reduction)) self.body = nn.Sequential(*modules_body) self.res_scale = res_scale def forward(self, x): res = self.body(x) # res = self.body(x).mul(self.res_scale) res += x return res ## Residual Group (RG) class ResidualGroup(nn.Module): def __init__( self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks ): super(ResidualGroup, self).__init__() modules_body = [] modules_body = [ RCAB( conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ) for _ in range(n_resblocks) ] 
modules_body.append(conv(n_feat, n_feat, kernel_size)) self.body = nn.Sequential(*modules_body) def forward(self, x): res = self.body(x) res += x return res ## Residual Channel Attention Network (RCAN) @ARCH_REGISTRY.register() class RCAN(nn.Module): def __init__(self, ng, nb, nf, reduction=16, upscale=4, conv=default_conv): super(RCAN, self).__init__() n_resgroups = ng n_resblocks = nb n_feats = nf kernel_size = 3 reduction = reduction scale = upscale act = nn.ReLU(True) # RGB mean for DIV2K rgb_mean = (0.4488, 0.4371, 0.4040) rgb_std = (1.0, 1.0, 1.0) self.sub_mean = MeanShift(1.0, rgb_mean, rgb_std, -1) # define head module modules_head = [conv(3, n_feats, kernel_size)] # define body module modules_body = [ ResidualGroup( conv, n_feats, kernel_size, reduction, act=act, res_scale=1.0, n_resblocks=nb, ) for _ in range(ng) ] modules_body.append(conv(n_feats, n_feats, kernel_size)) # define tail module modules_tail = [ Upsampler(conv, scale, n_feats, act=False), conv(n_feats, 3, kernel_size), ] self.add_mean = MeanShift(1.0, rgb_mean, rgb_std, 1) self.head = nn.Sequential(*modules_head) self.body = nn.Sequential(*modules_body) self.tail = nn.Sequential(*modules_tail) def forward(self, x): x = self.sub_mean(x) x = self.head(x) res = self.body(x) res += x x = self.tail(res) x = self.add_mean(x) return x def load_state_dict(self, state_dict, strict=False): own_state = self.state_dict() for name, param in state_dict.items(): if name in own_state: if isinstance(param, nn.Parameter): param = param.data try: own_state[name].copy_(param) except Exception: if name.find("tail") >= 0: print("Replace pre-trained upsampler to new one...") else: raise RuntimeError( "While copying the parameter named {}, " "whose dimensions in the model are {} and " "whose dimensions in the checkpoint are {}.".format( name, own_state[name].size(), param.size() ) ) elif strict: if name.find("tail") == -1: raise KeyError('unexpected key "{}" in state_dict'.format(name)) if strict: missing = 
set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))


================================================
FILE: codes/config/CinGAN/archs/rrdb.py
================================================
import functools

from utils.registry import ARCH_REGISTRY

from .module_util import *


class ResidualDenseBlock_5C(nn.Module):
    # Dense block of 5 convs; each conv sees the concat of all previous
    # feature maps. Output is residually scaled by 0.2 (ESRGAN convention).
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        initialize_weights(
            [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1
        )

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    """Residual in Residual Dense Block"""

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        # outer residual, again scaled by 0.2
        return out * 0.2 + x


@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
    # ESRGAN-style generator: first conv -> RRDB trunk -> nearest-neighbor
    # upsampling convs -> HR conv -> last conv.
    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4):
        super(RRDBNet, self).__init__()
        self.upscale = upscale

        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if upscale == 4:
            # second 2x stage only needed for x4
            self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk

        if self.upscale == 2 or self.upscale == 3:
            fea = self.lrelu(
                self.upconv1(
                    F.interpolate(fea, scale_factor=self.upscale, mode="nearest")
                )
            )
        if self.upscale == 4:
            fea = self.lrelu(
                self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest"))
            )
            fea = self.lrelu(
                self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest"))
            )
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out


================================================
FILE: codes/config/CinGAN/archs/srresnet.py
================================================
import functools

from utils.registry import ARCH_REGISTRY

from .module_util import *


@ARCH_REGISTRY.register()
class MSRResNet(nn.Module):
    """modified SRResNet"""

    # NOTE(review): only upscale in {2, 3, 4} defines self.upconv1; any other
    # value crashes below in initialize_weights — confirm configs never pass
    # another scale.
    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
        super(MSRResNet, self).__init__()
        self.upscale = upscale

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        basic_block = functools.partial(ResidualBlock_noBN, nf=nf)
        self.recon_trunk = make_layer(basic_block, nb)

        # upsampling
        if self.upscale == 2:
            self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(2)
        elif self.upscale == 3:
            self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(3)
        elif self.upscale == 4:
            self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(2)

        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        # activation function
        self.lrelu = 
nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # initialization
        initialize_weights(
            [self.conv_first, self.upconv1, self.HRconv, self.conv_last], 0.1
        )
        if self.upscale == 4:
            initialize_weights(self.upconv2, 0.1)

    def forward(self, x):
        fea = self.lrelu(self.conv_first(x))
        out = self.recon_trunk(fea)

        if self.upscale == 4:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        elif self.upscale == 3 or self.upscale == 2:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))

        out = self.conv_last(self.lrelu(self.HRconv(out)))
        # global residual: add a bilinear-upsampled copy of the input
        base = F.interpolate(
            x, scale_factor=self.upscale, mode="bilinear", align_corners=False
        )
        out += base
        return out


================================================
FILE: codes/config/CinGAN/archs/translator.py
================================================
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable  # NOTE(review): unused legacy import

from utils.registry import ARCH_REGISTRY


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    # "same"-padded conv for odd kernel sizes
    return nn.Conv2d(
        in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias
    )


class BasicBlock(nn.Sequential):
    # Conv -> (optional) BN -> (optional) activation as one Sequential.
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        bias=False,
        bn=True,
        act=nn.ReLU(True),
    ):
        m = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=(kernel_size // 2),
                stride=stride,
                bias=bias,
            )
        ]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)


class ResBlock(nn.Module):
    # EDSR-style residual block (duplicated from rcan.py in this config dir).
    def __init__(
        self,
        conv,
        n_feat,
        kernel_size,
        bias=True,
        bn=False,
        act=nn.ReLU(True),
        res_scale=1,
    ):
        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x
        return res


class Upsampler(nn.Sequential):
    # PixelShuffle upsampler; this translator variant also accepts scale == 1
    # (identity), unlike the rcan.py copy.
    def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
        m = []
        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feat))
                if act:
                    m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if act:
                m.append(act())
        elif scale == 1:
            m.append(nn.Identity())
        else:
            raise NotImplementedError
        super(Upsampler, self).__init__(*m)


@ARCH_REGISTRY.register()
class Translator(nn.Module):
    # Image-to-image translator with a global residual over a (possibly
    # rescaled) copy of the input; scale < 1 downsamples via a strided conv.
    def __init__(self, nb, nf, scale=4, zero_tail=False, conv=default_conv):
        super().__init__()
        self.scale = scale

        # define head module
        if scale >= 1:
            m_head = [conv(3, nf, 3)]
        else:
            # fractional scale: strided conv downsamples by s = 1/scale
            s = int(1 / scale)
            m_head = [nn.Conv2d(3, nf, kernel_size=2 * s + 1, stride=s, padding=s)]

        # define body module
        m_body = [
            ResBlock(conv, nf, 3, act=nn.ReLU(True), res_scale=1) for _ in range(nb)
        ]
        m_body.append(conv(nf, nf, 3))

        # define tail module
        m_tail = [
            Upsampler(conv, scale, nf, act=False) if scale > 1 else nn.Identity(),
            conv(nf, 3, 3),
        ]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

        if zero_tail:
            # start as (approximately) the identity mapping on the residual
            nn.init.constant_(self.tail[-1].weight, 0)
            nn.init.constant_(self.tail[-1].bias, 0)

    def forward(self, x):
        f = self.head(x)
        f = self.body(f)
        f = self.tail(f)
        if self.scale == 1:
            x = f + x
        else:
            x = f + F.interpolate(x, scale_factor=self.scale)
        return x


================================================
FILE: codes/config/CinGAN/archs/vgg.py
================================================
import os
from collections import OrderedDict

import torch
from torch import nn as nn
from torchvision.models import vgg as vgg

from utils.registry import ARCH_REGISTRY

VGG_PRETRAIN_PATH = "checkpoints/pretrained_models/vgg19-dcbb9e9d.pth"
# Layer-name sequences matching torchvision's vgg*.features ordering.
NAMES = {
    "vgg11": [
        "conv1_1",
        "relu1_1",
        "pool1",
        "conv2_1",
"relu2_1", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "pool5", ], "vgg13": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "pool5", ], "vgg16": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "pool5", ], "vgg19": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "conv3_4", "relu3_4", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "conv4_4", "relu4_4", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "conv5_4", "relu5_4", "pool5", ], } def insert_bn(names): """Insert bn layer after each conv. Args: names (list): The list of layer names. Returns: list: The list of layer names with bn layers. """ names_bn = [] for name in names: names_bn.append(name) if "conv" in name: position = name.replace("conv", "") names_bn.append("bn" + position) return names_bn @ARCH_REGISTRY.register() class VGGFeatureExtractor(nn.Module): """VGG network for feature extraction. In this implementation, we allow users to choose whether use normalization in the input feature and the type of vgg network. Note that the pretrained path must fit the vgg type. Args: layer_name_list (list[str]): Forward function returns the corresponding features according to the layer_name_list. Example: {'relu1_1', 'relu2_1', 'relu3_1'}. 
vgg_type (str): Set the type of vgg network. Default: 'vgg19'. use_input_norm (bool): If True, normalize the input image. Importantly, the input feature must in the range [0, 1]. Default: True. range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. Default: False. requires_grad (bool): If true, the parameters of VGG network will be optimized. Default: False. remove_pooling (bool): If true, the max pooling operations in VGG net will be removed. Default: False. pooling_stride (int): The stride of max pooling operation. Default: 2. """ def __init__( self, layer_name_list, vgg_type="vgg19", use_input_norm=True, range_norm=False, requires_grad=False, remove_pooling=False, pooling_stride=2, ): super(VGGFeatureExtractor, self).__init__() self.layer_name_list = layer_name_list self.use_input_norm = use_input_norm self.range_norm = range_norm self.names = NAMES[vgg_type.replace("_bn", "")] if "bn" in vgg_type: self.names = insert_bn(self.names) # only borrow layers that will be used to avoid unused params max_idx = 0 for v in layer_name_list: idx = self.names.index(v) if idx > max_idx: max_idx = idx if os.path.exists(VGG_PRETRAIN_PATH): vgg_net = getattr(vgg, vgg_type)(pretrained=False) state_dict = torch.load( VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage ) vgg_net.load_state_dict(state_dict) else: vgg_net = getattr(vgg, vgg_type)(pretrained=True) features = vgg_net.features[: max_idx + 1] modified_net = OrderedDict() for k, v in zip(self.names, features): if "pool" in k: # if remove_pooling is true, pooling operation will be removed if remove_pooling: continue else: # in some cases, we may want to change the default stride modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) else: modified_net[k] = v self.vgg_net = nn.Sequential(modified_net) if not requires_grad: self.vgg_net.eval() for param in self.parameters(): param.requires_grad = False else: self.vgg_net.train() for param in self.parameters(): param.requires_grad = True 
        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer(
                "mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
            )
            # the std is for image with range [0, 1]
            self.register_buffer(
                "std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
            )

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        # run layer by layer, capturing the requested intermediate features
        output = {}
        for key, layer in self.vgg_net._modules.items():
            x = layer(x)
            if key in self.layer_name_list:
                output[key] = x.clone()

        return output


================================================
FILE: codes/config/CinGAN/count_flops.py
================================================
# Script: prints a torchsummaryX summary for every network of the configured
# model; press Enter between networks.
import argparse
import sys

import torch
from torchsummaryX import summary

sys.path.append("../../")
import utils.option as option
from models import create_model

parser = argparse.ArgumentParser()
parser.add_argument(
    "--opt",
    type=str,
    default="options/setting1/test/test_setting1_x4.yml",
    help="Path to option YMAL file of Predictor.",
)
args = parser.parse_args()
opt = option.parse(args.opt, root_path=".", is_train=True)
opt = option.dict_to_nonedict(opt)

model = create_model(opt)

test_tensor = torch.randn(1, 3, 270, 180).cuda()
for name, net in model.networks.items():
    summary(net.cuda(), x=test_tensor)
    print("Above are results for net {}".format(name))
    input()


================================================
FILE: codes/config/CinGAN/inference.py
================================================
# Script: runs the configured model over every image in input_dir and writes
# the super-resolved results to output_dir.
import argparse
import logging
import math
import os
import os.path as osp
import random
import sys
import cv2
from collections import defaultdict
from glob import glob
from tqdm import tqdm

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from tensorboardX import SummaryWriter

sys.path.append("../../")
import utils as util
import utils.option as option
from data import create_dataloader, create_dataset
from data.data_sampler import DistIterSampler
from metrics import IQA
from models import create_model

#### options
parser = argparse.ArgumentParser()
parser.add_argument(
    "-opt",
    type=str,
    default="options/test/2020Track2.yml",
    help="Path to options YMAL file.",
)
parser.add_argument("-input_dir", type=str, default="../../../data_samples/LR")
parser.add_argument("-output_dir", type=str, default="../../../data_samples/BSRGAN")
args = parser.parse_args()

opt = option.parse(args.opt, is_train=False)
opt = option.dict_to_nonedict(opt)

model = create_model(opt)

if not osp.exists(args.output_dir):
    os.makedirs(args.output_dir)

test_files = glob(osp.join(args.input_dir, "*"))
for inx, path in tqdm(enumerate(test_files)):
    # NOTE(review): '/' split assumes POSIX paths — osp.basename would be
    # portable.
    name = path.split("/")[-1].split(".")[0]

    # BGR -> RGB, HWC -> NCHW, [0, 255] -> [0, 1]
    img = cv2.imread(path)[:, :, [2, 1, 0]]
    img = img.transpose(2, 0, 1)[None] / 255
    img_t = torch.as_tensor(np.ascontiguousarray(img)).float()

    model.test({"src": img_t}, crop_size=512)
    outdict = model.get_current_visuals()

    sr = outdict["sr"]
    sr_im = util.tensor2img(sr)
    save_path = osp.join(args.output_dir, "{}_x{}.png".format(name, opt["scale"]))
    cv2.imwrite(save_path, sr_im)


================================================
FILE: codes/config/CinGAN/models/__init__.py
================================================
import importlib
import logging
import os
import os.path as osp

from utils.registry import MODEL_REGISTRY

logger = logging.getLogger("base")

# import all *_model.py modules so their @MODEL_REGISTRY.register()
# decorators run and populate the registry
model_folder = osp.dirname(__file__)
model_names = [
    osp.splitext(osp.basename(v))[0]
    for v in os.listdir(model_folder)
    if v.endswith("_model.py")
]
_model_modules = [
    importlib.import_module(f"models.{file_name}") for file_name in model_names
]


def create_model(opt, **kwarg):
    """Instantiate the model class registered under opt["model"]."""
    model = opt["model"]
    m = MODEL_REGISTRY.get(model)(opt, **kwarg)
    logger.info("Model [{:s}] is created.".format(m.__class__.__name__))
    return m


================================================
FILE: codes/config/CinGAN/models/base_model.py
================================================
import logging
import os
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

from archs import build_loss, build_network, build_scheduler
from utils.registry import MODEL_REGISTRY

logger = logging.getLogger("base")


@MODEL_REGISTRY.register()
class BaseModel:
    # Common scaffolding for all models: holds networks/losses/optimizers/
    # schedulers dicts and provides build/save/load/LR utilities.
    def __init__(self, opt):
        self.opt = opt
        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
            self.world_size = torch.distributed.get_world_size()
        else:
            self.rank = 0  # non dist training
        self.device = torch.device("cuda" if opt["gpu_ids"] is not None else "cpu")
        self.is_train = opt["is_train"]
        self.log_dict = OrderedDict()

        self.data_names = []
        self.networks = {}
        self.optimizers = {}
        self.schedulers = {}

    def setup_train(self, train_opt):
        """Build losses, optimizers and schedulers from the train config."""
        # define losses
        loss_opt = train_opt["losses"]
        self.losses = self.build_losses(loss_opt)

        # build optmizers
        optimizer_opts = train_opt["optimizers"]
        self.optimizers = self.build_optimizers(optimizer_opts)

        # set schedulers
        scheduler_opts = train_opt["schedulers"]
        self.schedulers = self.build_schedulers(scheduler_opts)

        # set to training state
        self.set_network_state(self.networks.keys(), "train")

    # The following hooks are intentionally no-ops; subclasses override them.
    def feed_data(self, data):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        pass

    def get_current_losses(self):
        pass

    def print_network(self):
        pass

    def save(self, label):
        pass

    def load(self):
        pass

    def build_network(self, net_opt):
        """Build one network, move it to device and optionally load weights."""
        net = build_network(net_opt)
        if isinstance(net, nn.Module):
            net = self.model_to_device(net)
        if net_opt.get("pretrain"):
            pretrain = net_opt.pop("pretrain")
            self.load_network(net, pretrain["path"], pretrain["strict_load"])
        self.print_network(net)
        return net

    def build_losses(self, loss_opt):
        """Build loss modules; losses with weight <= 0 are silently skipped."""
        losses = {}
        defined_loss_names = list(loss_opt.keys())
        assert set(defined_loss_names).issubset(set(self.loss_names))

        for name in defined_loss_names:
            loss_conf = loss_opt.get(name)
            if loss_conf["weight"] > 0:
                self.loss_weights[name] = loss_conf.pop("weight")
                losses[name] = build_loss(loss_conf).to(self.device)

        return losses

    def build_optimizers(self, optim_opts):
        """Build one optimizer per named network; a "default" entry supplies
        the config for networks whose own entry is None."""
        optimizers = {}
        if "default" in optim_opts.keys():
            default_optim = optim_opts.pop("default")

        defined_optimizer_names = list(optim_opts.keys())
        assert set(defined_optimizer_names).issubset(self.networks.keys())

        for name in defined_optimizer_names:
            optim_opt = optim_opts[name]
            if optim_opt is None:
                optim_opt = default_optim.copy()

            # only parameters that require grad are optimized
            params = []
            for v in self.networks[name].parameters():
                if v.requires_grad:
                    params.append(v)

            optim_type = optim_opt.pop("type")
            optimizer = getattr(torch.optim, optim_type)(params=params, **optim_opt)
            optimizers[name] = optimizer

        return optimizers

    def build_schedulers(self, scheduler_opts):
        """Set up scheduler."""
        schedulers = {}

        if "default" in scheduler_opts.keys():
            default_opt = scheduler_opts.pop("default")

        for name in self.optimizers.keys():
            scheduler_opt = scheduler_opts[name]
            if scheduler_opt is None:
                scheduler_opt = default_opt.copy()

            schedulers[name] = build_scheduler(self.optimizers[name], scheduler_opt)

        return schedulers

    def model_to_device(self, net):
        """Model to device. It also warps models with DistributedDataParallel
        or DataParallel.
Args: net (nn.Module) """ net = net.to(self.device) if self.opt["dist"]: net = DistributedDataParallel(net, device_ids=[torch.cuda.current_device()]) else: net = DataParallel(net) return net def print_network(self, net): # Generator s, n = self.get_network_description(net) if isinstance(net, nn.DataParallel) or isinstance(net, DistributedDataParallel): net_struc_str = "{} - {}".format( net.__class__.__name__, net.module.__class__.__name__ ) else: net_struc_str = "{}".format(net.__class__.__name__) if self.rank <= 0: logger.info( "Network G structure: {}, with parameters: {:,d}".format( net_struc_str, n ) ) logger.info(s) def set_optimizer(self, names, operation): for name in names: getattr(self.optimizers[name], operation)() def set_requires_grad(self, names, requires_grad): for name in names: if isinstance(self.networks[name], nn.Module): for v in self.networks[name].parameters(): v.requires_grad = requires_grad def set_network_state(self, names, state): for name in names: if isinstance(self.networks[name], nn.Module): getattr(self.networks[name], state)() def clip_grad_norm(self, names, norm): for name in names: nn.utils.clip_grad_norm_(self.networks[name].parameters(), max_norm=norm) def _set_lr(self, lr_groups_l): """set learning rate for warmup, lr_groups_l: list for lr_groups. 
each for a optimizer""" for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): for param_group, lr in zip(optimizer.param_groups, lr_groups): param_group["lr"] = lr def _get_init_lr(self): # get the initial lr, which is set by the scheduler init_lr_groups_l = [] for optimizer in self.optimizers: init_lr_groups_l.append([v["initial_lr"] for v in optimizer.param_groups]) return init_lr_groups_l def update_learning_rate(self, cur_iter, warmup_iter=-1): for _, scheduler in self.schedulers.items(): scheduler.step() #### set up warm up learning rate if cur_iter < warmup_iter: # get initial lr for each group init_lr_g_l = self._get_init_lr() # modify warming-up learning rates warm_up_lr_l = [] for init_lr_g in init_lr_g_l: warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) # set learning rate self._set_lr(warm_up_lr_l) def get_current_learning_rate(self): # return self.schedulers[0].get_lr()[0] return list(self.optimizers.values())[0].param_groups[0]["lr"] def get_network_description(self, network): """Get the string and total parameters of the network""" if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module s = str(network) n = sum(map(lambda x: x.numel(), network.parameters())) return s, n def save_network(self, network, network_label, iter_label): save_filename = "{}_{}.pth".format(iter_label, network_label) save_path = os.path.join(self.opt["path"]["models"], save_filename) if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module state_dict = network.state_dict() for key, param in state_dict.items(): state_dict[key] = param.cpu() torch.save(state_dict, save_path) def save(self, iter_label): for name in self.optimizers.keys(): self.save_network(self.networks[name], name, iter_label) def load_network(self, network, load_path, strict=True): if load_path is not None: if isinstance(network, nn.DataParallel) or isinstance( 
network, DistributedDataParallel ): network = network.module load_net = torch.load(load_path) load_net_clean = OrderedDict() # remove unnecessary 'module.' for k, v in load_net.items(): if k.startswith("module."): load_net_clean[k[7:]] = v else: load_net_clean[k] = v network.load_state_dict(load_net_clean, strict=strict) def save_training_state(self, epoch, iter_step): """Saves training state during training, which will be used for resuming""" state = {"epoch": epoch, "iter": iter_step, "schedulers": {}, "optimizers": {}} for k, s in self.schedulers.items(): state["schedulers"][k] = s.state_dict() for k, o in self.optimizers.items(): state["optimizers"][k] = o.state_dict() save_filename = "{}.state".format(iter_step) save_path = os.path.join(self.opt["path"]["training_state"], save_filename) torch.save(state, save_path) def resume_training(self, resume_state): """Resume the optimizers and schedulers for training""" resume_optimizers = resume_state["optimizers"] resume_schedulers = resume_state["schedulers"] assert len(resume_optimizers) == len( self.optimizers ), "Wrong lengths of optimizers" assert len(resume_schedulers) == len( self.schedulers ), "Wrong lengths of schedulers" for name, o in resume_optimizers.items(): self.optimizers[name].load_state_dict(o) for name, s in resume_schedulers.items(): self.schedulers[name].load_state_dict(s) def reduce_loss_dict(self, loss_dict): """reduce loss dict. In distributed training, it averages the losses among different GPUs . Args: loss_dict (OrderedDict): Loss dict. 
""" with torch.no_grad(): if self.opt["dist"]: keys = [] losses = [] for name, value in loss_dict.items(): keys.append(name) losses.append(value) losses = torch.stack(losses, 0) torch.distributed.reduce(losses, dst=0) if self.rank == 0: losses /= self.world_size loss_dict = {key: loss for key, loss in zip(keys, losses)} log_dict = OrderedDict() for name, value in loss_dict.items(): log_dict[name] = value.mean().item() return log_dict def get_current_log(self): return self.log_dict ================================================ FILE: codes/config/CinGAN/models/cingan_model.py ================================================ import logging from collections import OrderedDict import torch import torch.nn as nn from utils.registry import MODEL_REGISTRY from .base_model import BaseModel from .trans_model import ShuffleBuffer logger = logging.getLogger("base") @MODEL_REGISTRY.register() class CinGANModel(BaseModel): def __init__(self, opt): super().__init__(opt) if opt["dist"]: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training self.data_names = ["syn_lr", "syn_hr", "real_lr"] self.network_names = ["netSR", "netG1", "netG2", "netG3", "netD1", "netD2"] self.networks = {} self.loss_names = [ "srd2_adv", "sr_tv", "srg3_cycle", "g1d1_adv", "g1g2_cycle", "lr_tv", ] self.loss_weights = {} self.losses = {} self.optimizers = {} # define networks and load pretrained models nets_opt = opt["networks"] defined_network_names = list(nets_opt.keys()) assert set(defined_network_names).issubset(set(self.network_names)) for name in defined_network_names: setattr(self, name, self.build_network(nets_opt[name])) self.networks[name] = getattr(self, name) if self.is_train: train_opt = opt["train"] # setup loss, optimizers, schedulers self.setup_train(train_opt["train"]) self.max_grad_norm = train_opt["max_grad_norm"] # buffer self.fake_lr_buffer = ShuffleBuffer(train_opt["buffer_size"]) self.fake_hr_buffer = ShuffleBuffer(train_opt["buffer_size"]) def 
feed_data(self, data): self.syn_lr = data["ref_src"].to(self.device) self.syn_hr = data["ref_tgt"].to(self.device) self.real_lr = data["src"].to(self.device) def foward_trans(self): self.fake_syn_lr = self.netG1(self.real_lr) self.rec_real_lr = self.netG2(self.fake_syn_lr) def forward_sr(self): self.fake_syn_lr = self.netG1(self.real_lr) self.fake_real_hr = self.netSR(self.fake_syn_lr) self.rec_real_lr = self.netG3(self.fake_real_hr) def optimize_parameters(self, step): loss_dict = OrderedDict() # update trans ## update generators self.set_requires_grad(["netD1"], False) self.foward_trans() loss_G = 0 g1_adv_loss = self.calculate_gan_loss_G( self.netD1, self.losses["g1d1_adv"], self.syn_lr, self.fake_syn_lr ) loss_dict["g1_adv"] = g1_adv_loss.item() loss_G += self.loss_weights["g1d1_adv"] * g1_adv_loss if self.losses.get("lr_tv"): lr_tv_loss = self.losses["lr_tv"](self.fake_syn_lr) loss_dict["lr_tv"] = lr_tv_loss.item() loss_G += self.loss_weights["lr_tv"] * lr_tv_loss g1g2_cycle = self.losses["g1g2_cycle"](self.rec_real_lr, self.real_lr) loss_dict["g1g2_cycle"] = g1g2_cycle.item() loss_G += self.loss_weights["g1g2_cycle"] * g1g2_cycle self.set_optimizer(names=["netG1","netG2"], operation="zero_grad") loss_G.backward() self.clip_grad_norm(["netG1","netG2"], norm=self.max_grad_norm) self.set_optimizer(names=["netG1", "netG2"], operation="step") ## update D self.set_requires_grad(["netD1"], True) loss_d1 = self.calculate_gan_loss_D( self.netD1, self.losses["g1d1_adv"], self.syn_lr, self.fake_lr_buffer.choose(self.fake_syn_lr) ) loss_dict["d1_adv"] = loss_d1.item() loss_D = self.loss_weights["g1d1_adv"] * loss_d1 self.set_optimizer(["netD1"], "zero_grad") loss_D.backward() self.clip_grad_norm(["netD1"], self.max_grad_norm) self.set_optimizer(["netD1"], "step") # update sr self.set_requires_grad(["netD2"], False) self.forward_sr() loss_G = 0 srd2_adv_g = self.calculate_gan_loss_G( self.netD2, self.losses["srd2_adv"], self.syn_hr, self.fake_real_hr ) loss_dict["sr_adv"] 
= srd2_adv_g.item() loss_G += self.loss_weights["srd2_adv"] * srd2_adv_g if self.losses.get("sr_tv"): sr_tv_loss = self.losses["sr_tv"](self.fake_real_hr) loss_dict["sr_tv"] = sr_tv_loss.item() loss_G += self.loss_weights["sr_tv"] * sr_tv_loss srg3_cycle = self.losses["srg3_cycle"](self.rec_real_lr, self.real_lr) loss_dict["srg3_cycle"] = srg3_cycle.item() loss_G += self.loss_weights["srg3_cycle"] * srg3_cycle self.set_optimizer(names=["netG1", "netSR", "netG3"], operation="zero_grad") loss_G.backward() self.clip_grad_norm(names=["netG1", "netSR", "netG3"], norm=self.max_grad_norm) self.set_optimizer(names=["netG1", "netSR", "netG3"], operation="step") ## update D1, D2 self.set_requires_grad(["netD2"], True) loss_d2 = self.calculate_gan_loss_D( self.netD2, self.losses["srd2_adv"], self.syn_hr, self.fake_hr_buffer.choose(self.fake_real_hr.detach()) ) loss_dict["d1_adv"] = loss_d2.item() loss_D = self.loss_weights["srd2_adv"] * loss_d2 self.set_optimizer(names=["netD2"], operation="zero_grad") loss_D.backward() self.clip_grad_norm(["netD2"], self.max_grad_norm) self.set_optimizer(names=["netD2"], operation="step") self.log_dict = loss_dict def calculate_gan_loss_D(self, netD, criterion, real, fake): d_pred_fake = netD(fake.detach()) d_pred_real = netD(real) loss_real = criterion(d_pred_real, True, is_disc=True) loss_fake = criterion(d_pred_fake, False, is_disc=True) return (loss_real + loss_fake) / 2 def calculate_gan_loss_G(self, netD, criterion, real, fake): d_pred_fake = netD(fake) loss_real = criterion(d_pred_fake, True, is_disc=False) return loss_real def calculate_rgan_loss_D(self, netD, criterion, real, fake): d_pred_fake = netD(fake.detach()) d_pred_real = netD(real) loss_real = criterion( d_pred_real - d_pred_fake.detach().mean(), True, is_disc=False ) loss_fake = criterion( d_pred_fake - d_pred_real.detach().mean(), False, is_disc=False ) loss = (loss_real + loss_fake) / 2 return loss def calculate_rgan_loss_G(self, netD, criterion, real, fake): d_pred_fake 
= netD(fake)
        d_pred_real = netD(real).detach()
        # Relativistic generator loss: the real branch is detached, so only the
        # generator (through d_pred_fake) receives gradients.
        loss_real = criterion(d_pred_real - d_pred_fake.mean(), False, is_disc=False)
        loss_fake = criterion(d_pred_fake - d_pred_real.mean(), True, is_disc=False)
        loss = (loss_real + loss_fake) / 2
        return loss

    def test(self, data):
        # Inference: translate the real LR into the bicubic-LR domain (netG1),
        # then super-resolve the translated image (netSR).
        self.real_lr = data["src"].to(self.device)

        self.set_network_state(["netSR", "netG1"], "eval")
        with torch.no_grad():
            self.fake_syn_lr = self.netG1(self.real_lr)
            self.fake_real_hr = self.netSR(self.fake_syn_lr)
        self.set_network_state(["netSR", "netG1"], "train")

    def get_current_visuals(self, need_GT=True):
        # First sample of the batch, detached and moved to CPU for saving/logging.
        out_dict = OrderedDict()
        out_dict["lr"] = self.real_lr.detach()[0].float().cpu()
        out_dict["sr"] = self.fake_real_hr.detach()[0].float().cpu()
        return out_dict


================================================
FILE: codes/config/CinGAN/models/trans_model.py
================================================
import logging
from collections import OrderedDict
import random

import torch
import torch.nn as nn

from utils.registry import MODEL_REGISTRY

from .base_model import BaseModel

logger = logging.getLogger("base")


@MODEL_REGISTRY.register()
class TransModel(BaseModel):
    """LR-to-LR domain translation (CinGAN stage 1): netG1/netG2 cycle + netD1."""

    def __init__(self, opt):
        super().__init__(opt)

        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        # keys expected in each training batch
        self.data_names = ["src", "tgt"]

        self.network_names = ["netG1", "netG2", "netD1"]
        self.networks = {}

        self.loss_names = [
            "g1d1_adv", "g1_idt", "g1g2_cycle", "lr_tv"
        ]
        self.loss_weights = {}
        self.losses = {}
        self.optimizers = {}

        # define networks and load pretrained models
        nets_opt = opt["networks"]
        defined_network_names = list(nets_opt.keys())
        # only networks this model knows about may be configured
        assert set(defined_network_names).issubset(set(self.network_names))

        for name in defined_network_names:
            setattr(self, name, self.build_network(nets_opt[name]))
            self.networks[name] = getattr(self, name)

        if self.is_train:
            train_opt = opt["train"]
            # setup loss, optimizers, schedulers
            # NOTE(review): train_opt is already opt["train"], so this passes
            # opt["train"]["train"] to setup_train -- confirm setup_train
            # really expects a nested "train" key here.
            self.setup_train(train_opt["train"])

            self.max_grad_norm =
train_opt["max_grad_norm"] # buffer self.fake_tgt_buffer = ShuffleBuffer(train_opt["buffer_size"]) def feed_data(self, data): self.src = data["src"].to(self.device) self.tgt = data["tgt"].to(self.device) def forward(self): self.fake_tgt = self.netG1(self.src) self.rec_src = self.netG2(self.fake_tgt) def optimize_parameters(self, step): loss_dict = OrderedDict() self.forward() loss_G = 0 # set D fixed self.set_requires_grad(["netD1"], False) g1_adv_loss = self.calculate_gan_loss_G( self.netD1, self.losses["g1d1_adv"], self.tgt, self.fake_tgt ) loss_dict["g1_adv"] = g1_adv_loss.item() loss_G += self.loss_weights["g1d1_adv"] * g1_adv_loss if self.losses.get("g1_idt"): self.tgt_idt = self.netG1(self.tgt) g1_idt = self.losses["g1_idt"](self.tgt, self.tgt_idt) loss_dict["g1_idt"] = g1_idt.item() loss_G += self.loss_weights["g1_idt"] * g1_idt if self.losses.get("lr_tv"): lr_tv = self.losses["lr_tv"](self.fake_tgt) loss_dict["lr_tv"] = lr_tv.item() loss_G += self.loss_weights["lr_tv"] * lr_tv g1g2_cycle = self.losses["g1g2_cycle"](self.rec_src, self.src) loss_dict["g1g2_cycle"] = g1g2_cycle.item() loss_G += self.loss_weights["g1g2_cycle"] * g1g2_cycle self.set_optimizer(names=["netG1", "netG2"], operation="zero_grad") loss_G.backward() self.clip_grad_norm(names=["netG1", "netG2"], norm=self.max_grad_norm) self.set_optimizer(names=["netG1", "netG2"], operation="step") ## update D1, D2 self.set_requires_grad(["netD1"], True) loss_D = 0 loss_d1 = self.calculate_gan_loss_D( self.netD1, self.losses["g1d1_adv"], self.tgt, self.fake_tgt_buffer.choose(self.fake_tgt.detach()) ) loss_dict["d1_adv"] = loss_d1.item() loss_D += loss_d1 self.set_optimizer(names=["netD1"], operation="zero_grad") loss_D.backward() self.clip_grad_norm(names=["netG1"], norm=self.max_grad_norm) self.set_optimizer(names=["netD1"], operation="step") self.log_dict = loss_dict def calculate_gan_loss_D(self, netD, criterion, real, fake): d_pred_fake = netD(fake.detach()) d_pred_real = netD(real) loss_real = 
criterion(d_pred_real, True, is_disc=True)
        loss_fake = criterion(d_pred_fake, False, is_disc=True)
        return (loss_real + loss_fake) / 2

    def calculate_gan_loss_G(self, netD, criterion, real, fake):
        # Generator side: make the discriminator classify the fake as real.
        d_pred_fake = netD(fake)
        loss_real = criterion(d_pred_fake, True, is_disc=False)
        return loss_real

    def calculate_rgan_loss_D(self, netD, criterion, real, fake):
        # Relativistic GAN: compare each prediction against the (detached)
        # mean prediction of the opposite class.
        d_pred_fake = netD(fake.detach())
        d_pred_real = netD(real)
        loss_real = criterion(
            d_pred_real - d_pred_fake.detach().mean(), True, is_disc=False
        )
        loss_fake = criterion(
            d_pred_fake - d_pred_real.detach().mean(), False, is_disc=False
        )
        loss = (loss_real + loss_fake) / 2
        return loss

    def calculate_rgan_loss_G(self, netD, criterion, real, fake):
        # Relativistic generator loss; the real branch is detached so only the
        # generator (through d_pred_fake) receives gradients.
        d_pred_fake = netD(fake)
        d_pred_real = netD(real).detach()
        loss_real = criterion(d_pred_real - d_pred_fake.mean(), False, is_disc=False)
        loss_fake = criterion(d_pred_fake - d_pred_real.mean(), True, is_disc=False)
        loss = (loss_real + loss_fake) / 2
        return loss

    def test(self, data):
        # Inference-only translation of the source image.
        self.src = data["src"].to(self.device)
        self.netG1.eval()
        with torch.no_grad():
            self.fake_tgt = self.netG1(self.src)
        self.netG1.train()

    def get_current_visuals(self, need_GT=True):
        # First sample of the batch, detached and on CPU, for saving/logging.
        out_dict = OrderedDict()
        out_dict["lr"] = self.src.detach()[0].float().cpu()
        out_dict["sr"] = self.fake_tgt.detach()[0].float().cpu()
        return out_dict


class ShuffleBuffer():
    """Random choose previous generated images or ones produced by the latest generators.

    :param buffer_size: the size of image buffer
    :type buffer_size: int
    """

    def __init__(self, buffer_size):
        """Initialize the ImagePool class.

        :param buffer_size: the size of image buffer
        :type buffer_size: int
        """
        self.buffer_size = buffer_size
        self.num_imgs = 0  # number of images currently stored (<= buffer_size)
        self.images = []   # the stored image tensors

    def choose(self, images, prob=0.5):
        """Return an image from the pool.
:param images: the latest generated images from the generator
        :type images: list
        :param prob: probability (0~1) of return previous images from buffer
        :type prob: float

        :return: Return images from the buffer
        :rtype: list
        """
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.buffer_size:
                # buffer not full yet: store the new image and return it
                self.images.append(image)
                return_images.append(image)
                self.num_imgs += 1
            else:
                p = random.uniform(0, 1)
                if p < prob:
                    # swap: return a previously stored image and replace it
                    # with the new one
                    idx = random.randint(0, self.buffer_size - 1)
                    stored_image = self.images[idx].clone()
                    self.images[idx] = image
                    return_images.append(stored_image)
                else:
                    return_images.append(image)
        return_images = torch.cat(return_images, 0)
        return return_images


================================================
FILE: codes/config/CinGAN/options/test/sr/2017Track1.yml
================================================
#### general settings
name: 2017Track1
use_tb_logger: false
model: CinGANModel
scale: 4
gpu_ids: [0]
metrics: [psnr, ssim, lpips, niqe, piqe, brisque]

datasets:
  test1:
    name: 2017Track1
    mode: PairedDataset
    data_type: lmdb
    dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb
    dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb
  # test2:
  #   name: 2018Track2
  #   mode: PairedDataset
  #   data_type: lmdb
  #   dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_mild.lmdb
  #   dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb
  # test3:
  #   name: 2018Track3
  #   mode: PairedDataset
  #   data_type: lmdb
  #   dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_difficult.lmdb
  #   dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb
  # test4:
  #   name: 2018Track4
  #   mode: PairedDataset
  #   data_type: lmdb
  #   dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_wild.lmdb
  #   dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb
  # test5:
  #   name: 2020Track1
  #   mode: PairedDataset
  #   data_type: lmdb
  #   dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1_valid_input.lmdb
  #   dataroot_tgt: /home/lzx/SRDatasets/NTIRE2020/track1_valid_gt.lmdb

#### network
structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/CinGAN2017Track1/models/latest_netSR.pth strict_load: true netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: log/CinGAN2017Track1/models/latest_netG1.pth strict_load: true ================================================ FILE: codes/config/CinGAN/options/test/sr/2018Track2.yml ================================================ #### general settings name: 2018Track2 use_tb_logger: false model: CinGANModel scale: 4 gpu_ids: [5] metrics: [best_psnr, best_ssim, lpips, niqe, piqe, brisque] datasets: # test1: # name: 2017Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb test2: name: 2018Track2 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test3: # name: 2018Track3 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_difficult.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test4: # name: 2018Track4 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_wild.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test5: # name: 2020Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1_valid_input.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2020/track1_valid_gt.lmdb #### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/CinGAN2018Track2/models/latest_netSR.pth strict_load: true netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: log/CinGAN2018Track2/models/latest_netG1.pth strict_load: true 
================================================ FILE: codes/config/CinGAN/options/test/sr/2018Track4.yml ================================================ #### general settings name: 2018Track4 use_tb_logger: false model: CinGANModel scale: 4 gpu_ids: [5] metrics: [best_psnr, best_ssim, lpips, niqe, piqe, brisque] datasets: # test1: # name: 2017Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test2: # name: 2018Track2 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test3: # name: 2018Track3 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_difficult.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb test4: name: 2018Track4 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test5: # name: 2020Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1_valid_input.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2020/track1_valid_gt.lmdb #### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/CinGAN2018Track4/models/latest_netSR.pth strict_load: true netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: log/CinGAN2018Track4/models/latest_netG1.pth strict_load: true ================================================ FILE: codes/config/CinGAN/options/test/sr/2020Track1.yml ================================================ #### general settings name: 2020Track1 use_tb_logger: false model: CinGANModel scale: 4 gpu_ids: [1] metrics: [psnr, ssim, lpips, niqe, piqe, brisque] datasets: # test1: # name: 2017Track1 # mode: 
PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test2: # name: 2018Track2 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test3: # name: 2018Track3 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_difficult.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test4: # name: 2018Track4 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb test5: name: 2020Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/CinGAN2020Track1/models/latest_netSR.pth strict_load: true netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: log/CinGAN2020Track1/models/latest_netG1.pth strict_load: true ================================================ FILE: codes/config/CinGAN/options/train/sr/2017Track2.yml ================================================ #### general settings name: CinGAN2017Track2 use_tb_logger: false model: CinGANModel scale: 4 gpu_ids: [5] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: PairedRefDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_ref_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_ref_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/train_LR/x4_half.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: 
name: 2017Track1_mini
    mode: PairedDataset
    data_type: lmdb
    color: RGB
    dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4_mini.lmdb
    dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb

networks:
  netSR:
    which_network: EDSR
    setting:
      nf: 64
      nb: 16
      res_scale: 1
      upscale: 4
    pretrain:
      path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt
      strict_load: true
  netD2:
    which_network: PatchGANDiscriminator
    setting:
      in_c: 3
      nf: 64
      nb: 3
      stride: 1
    pretrain:
      # fixed: the checkpoint path had been misplaced into strict_load,
      # leaving path as ~ (null)
      path: log/Trans2017Track1/models/latest_netD2.pth
      strict_load: true
  #### network structures
  netG1:
    which_network: Translator
    setting:
      nf: 64
      nb: 8
      zero_tail: true
      scale: 1
    pretrain:
      path: log/Trans2017Track1/models/latest_netG1.pth
      strict_load: true
  netD1:
    which_network: PatchGANDiscriminator
    setting:
      in_c: 3
      nf: 64
      nb: 3
      stride: 1
    pretrain:
      path: log/Trans2017Track1/models/latest_netD1.pth
      strict_load: true
  netG2:
    which_network: Translator
    setting:
      nf: 64
      nb: 8
      zero_tail: true
      scale: 1
    pretrain:
      path: log/Trans2017Track1/models/latest_netG2.pth
      strict_load: true
  netG3:
    which_network: Translator
    setting:
      nf: 64
      nb: 8
      zero_tail: true
      scale: 0.25
    pretrain:
      path: ~
      strict_load: true

#### training settings: learning rate scheme, loss
train:
  resume_state: ~
  max_grad_norm: 50
  buffer_size: 16

  losses:
    srd2_adv:
      type: GANLoss
      gan_type: lsgan
      real_label_val: 1.0
      fake_label_val: 0.0
      weight: !!float 0.5
    sr_tv:
      type: TVLoss
      penealty: MSELoss
      weight: 2
    srg3_cycle:
      type: L1Loss
      weight: 10
    g1d1_adv:
      type: GANLoss
      gan_type: lsgan
      real_label_val: 1.0
      fake_label_val: 0.0
      weight: !!float 1.0
    lr_tv:
      type: TVLoss
      penealty: MSELoss
      weight: 0.5
    g1g2_cycle:
      type: L1Loss
      weight: 10.0

  optimizers:
    default:
      type: Adam
      lr: !!float 1e-4
      betas: [0.5, 0.999]
    netSR: ~
    netG1: ~
    netG2: ~
    netD1: ~
    netD2: ~
    netG3: ~

  niter: 200000
  warmup_iter: -1  # no warm up

  schedulers:
    default:
      type: MultiStepRestartLR
      milestones: [50000, 100000, 150000]
      gamma: 0.5

  manual_seed: 0
  val_freq: !!float 5e3

#### logger
logger:
  print_freq: 100
  save_checkpoint_freq: !!float
5e3


================================================
FILE: codes/config/CinGAN/options/train/sr/2018Track2.yml
================================================
#### general settings
name: CinGAN2018Track2
use_tb_logger: false
model: CinGANModel
scale: 4
gpu_ids: [6]
metrics: [psnr, ssim, lpips]

#### datasets
datasets:
  train:
    name: DIV2K
    mode: PairedRefDataset
    data_type: lmdb
    color: RGB
    ratios: [200, 200]
    dataroot_ref_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb
    dataroot_ref_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4_half.lmdb
    dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/x4_half.lmdb

    use_shuffle: true
    workers_per_gpu: 8  # per GPU
    imgs_per_gpu: 32
    tgt_size: 128
    src_size: 32
    use_flip: true
    use_rot: true

  val:
    name: 2018Track2_mini
    mode: PairedDataset
    data_type: lmdb
    color: RGB
    dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid_mini.lmdb
    dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb

networks:
  netSR:
    which_network: EDSR
    setting:
      nf: 64
      nb: 16
      res_scale: 1
      upscale: 4
    pretrain:
      path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt
      strict_load: true
  netD2:
    which_network: PatchGANDiscriminator
    setting:
      in_c: 3
      nf: 64
      nb: 3
      stride: 1
    pretrain:
      # fixed: the checkpoint path had been misplaced into strict_load,
      # leaving path as ~ (null)
      path: log/Trans2018Track2/models/latest_netD2.pth
      strict_load: true
  #### network structures
  netG1:
    which_network: Translator
    setting:
      nf: 64
      nb: 8
      zero_tail: true
      scale: 1
    pretrain:
      path: log/Trans2018Track2/models/latest_netG1.pth
      strict_load: true
  netD1:
    which_network: PatchGANDiscriminator
    setting:
      in_c: 3
      nf: 64
      nb: 3
      stride: 1
    pretrain:
      path: log/Trans2018Track2/models/latest_netD1.pth
      strict_load: true
  netG2:
    which_network: Translator
    setting:
      nf: 64
      nb: 8
      zero_tail: true
      scale: 1
    pretrain:
      path: log/Trans2018Track2/models/latest_netG2.pth
      strict_load: true
  netG3:
    which_network: Translator
    setting:
      nf: 64
      nb: 8
      zero_tail: true
      scale: 0.25
    pretrain:
      path: ~
      strict_load: true

#### training settings: learning rate scheme, loss
train:
  resume_state: ~
  max_grad_norm: 50
  buffer_size: 16

  losses:
srd2_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.5 sr_tv: type: TVLoss penealty: MSELoss weight: 2 srg3_cycle: type: L1Loss weight: 10 g1d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 lr_tv: type: TVLoss penealty: MSELoss weight: 0.5 g1g2_cycle: type: L1Loss weight: 10.0 optimizers: default: type: Adam lr: !!float 1e-4 betas: [0.5, 0.999] netSR: ~ netG1: ~ netG2: ~ netD1: ~ netD2: ~ netG3: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/CinGAN/options/train/sr/2018Track4.yml ================================================ #### general settings name: CinGAN2018Track4 use_tb_logger: false model: CinGANModel scale: 4 gpu_ids: [1] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: PairedRefDataset data_type: lmdb color: RGB ratios: [200, 50] dataroot_ref_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_ref_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/x4.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2018Track4_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures # netSR: # which_network: RRDBNet # setting: # in_nc: 3 # out_nc: 3 # nf: 64 # nb: 23 # upscale: 4 # pretrain: # path: ../../../checkpoints/ESRGAN/RRDB_PSNR_x4.pth # strict_load: true networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: 
../../../checkpoints/EDSR/edsr_baseline_x4-new.pt
      strict_load: true
  netD2:
    which_network: PatchGANDiscriminator
    setting:
      in_c: 3
      nf: 64
      nb: 3
      stride: 1
    pretrain:
      # fixed: the checkpoint path had been misplaced into strict_load,
      # leaving path as ~ (null)
      path: log/Trans2018Track4/models/latest_netD2.pth
      strict_load: true
  #### network structures
  netG1:
    which_network: Translator
    setting:
      nf: 64
      nb: 8
      zero_tail: true
      scale: 1
    pretrain:
      path: log/Trans2018Track4/models/latest_netG1.pth
      strict_load: true
  netD1:
    which_network: PatchGANDiscriminator
    setting:
      in_c: 3
      nf: 64
      nb: 3
      stride: 1
    pretrain:
      path: log/Trans2018Track4/models/latest_netD1.pth
      strict_load: true
  netG2:
    which_network: Translator
    setting:
      nf: 64
      nb: 8
      zero_tail: true
      scale: 1
    pretrain:
      path: log/Trans2018Track4/models/latest_netG2.pth
      strict_load: true
  netG3:
    which_network: Translator
    setting:
      nf: 64
      nb: 8
      zero_tail: true
      scale: 0.25
    pretrain:
      path: ~
      strict_load: true

#### training settings: learning rate scheme, loss
train:
  resume_state: ~
  max_grad_norm: 50
  buffer_size: 16

  losses:
    srd2_adv:
      type: GANLoss
      gan_type: lsgan
      real_label_val: 1.0
      fake_label_val: 0.0
      weight: !!float 0.5
    sr_tv:
      type: TVLoss
      penealty: MSELoss
      weight: 2
    srg3_cycle:
      type: L1Loss
      weight: 10
    g1d1_adv:
      type: GANLoss
      gan_type: lsgan
      real_label_val: 1.0
      fake_label_val: 0.0
      weight: !!float 1.0
    lr_tv:
      type: TVLoss
      penealty: MSELoss
      weight: 0.5
    g1g2_cycle:
      type: L1Loss
      weight: 10.0

  optimizers:
    default:
      type: Adam
      lr: !!float 1e-4
      betas: [0.5, 0.999]
    netSR: ~
    netG1: ~
    netG2: ~
    netD1: ~
    netD2: ~
    netG3: ~

  niter: 200000
  warmup_iter: -1  # no warm up

  schedulers:
    default:
      type: MultiStepRestartLR
      milestones: [50000, 100000, 150000]
      gamma: 0.5

  manual_seed: 0
  val_freq: !!float 5e3

#### logger
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 5e3


================================================
FILE: codes/config/CinGAN/options/train/sr/2020Track1.yml
================================================
#### general settings
name: CinGAN2020Track1
use_tb_logger: false
model: CinGANModel
scale: 4
gpu_ids: [5]
metrics: [psnr, ssim, lpips]

####
datasets
datasets:
  train:
    name: DIV2K
    mode: PairedRefDataset
    data_type: lmdb
    color: RGB
    ratios: [200, 50]
    dataroot_ref_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb
    dataroot_ref_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4.lmdb
    dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/train_source.lmdb

    use_shuffle: true
    workers_per_gpu: 8  # per GPU
    imgs_per_gpu: 32
    tgt_size: 128
    src_size: 32
    use_flip: true
    use_rot: true

  val:
    name: 2020Track1_mini
    mode: PairedDataset
    data_type: lmdb
    color: RGB
    dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid_mini.lmdb
    dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb

#### network structures
# netSR:
#   which_network: RRDBNet
#   setting:
#     in_nc: 3
#     out_nc: 3
#     nf: 64
#     nb: 23
#     upscale: 4
#   pretrain:
#     path: ../../../checkpoints/ESRGAN/RRDB_PSNR_x4.pth
#     strict_load: true

networks:
  netSR:
    which_network: EDSR
    setting:
      nf: 64
      nb: 16
      res_scale: 1
      upscale: 4
    pretrain:
      path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt
      strict_load: true
  netD2:
    which_network: PatchGANDiscriminator
    setting:
      in_c: 3
      nf: 64
      nb: 3
      stride: 1
    pretrain:
      # fixed: the checkpoint path had been misplaced into strict_load,
      # leaving path as ~ (null)
      path: log/Trans2020Track1/models/100000_netD2.pth
      strict_load: true
  #### network structures
  netG1:
    which_network: Translator
    setting:
      nf: 64
      nb: 8
      zero_tail: true
      scale: 1
    pretrain:
      path: log/Trans2020Track1/models/100000_netG1.pth
      strict_load: true
  netD1:
    which_network: PatchGANDiscriminator
    setting:
      in_c: 3
      nf: 64
      nb: 3
      stride: 1
    pretrain:
      path: log/Trans2020Track1/models/100000_netD1.pth
      strict_load: true
  netG2:
    which_network: Translator
    setting:
      nf: 64
      nb: 8
      zero_tail: true
      scale: 1
    pretrain:
      path: log/Trans2020Track1/models/100000_netG2.pth
      strict_load: true
  netG3:
    which_network: Translator
    setting:
      nf: 64
      nb: 8
      zero_tail: true
      scale: 0.25
    pretrain:
      path: ~
      strict_load: true

#### training settings: learning rate scheme, loss
train:
  resume_state: ~
  max_grad_norm: 50
  buffer_size: 16

  losses:
    srd2_adv:
      type: GANLoss
      gan_type: lsgan
      real_label_val: 1.0
      fake_label_val: 0.0
      weight: !!float 0.5
sr_tv: type: TVLoss penealty: MSELoss weight: 2 srg3_cycle: type: L1Loss weight: 10 g1d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 lr_tv: type: TVLoss penealty: MSELoss weight: 0.5 g1g2_cycle: type: L1Loss weight: 10.0 optimizers: default: type: Adam lr: !!float 1e-4 betas: [0.5, 0.999] netSR: ~ netG1: ~ netG2: ~ netD1: ~ netD2: ~ netG3: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/CinGAN/options/train/trans/2017Track2.yml ================================================ #### general settings name: Trans2017Track2 use_tb_logger: false model: TransModel scale: 1 gpu_ids: [2] metrics: [psnr, ssim] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/train_LR/x4_half.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 32 src_size: 32 use_flip: true use_rot: true val: name: DIV2K mode: PairedDataset data_type: lmdb color: RGB dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/BicLR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb #### network structures networks: netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ buffer_size: 16 max_grad_norm: 50 losses: g1d1_adv: type: GANLoss 
gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 lr_tv: type: TVLoss penealty: MSELoss weight: 0.5 g1_idt: type: L1Loss weight: 5.0 g1g2_cycle: type: L1Loss weight: 10.0 optimizers: default: type: Adam lr: !!float 2e-4 betas: [0.5, 0.999] netG1: ~ netG2: ~ netD1: ~ niter: 100000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/CinGAN/options/train/trans/2018Track2.yml ================================================ #### general settings name: Trans2018Track2 use_tb_logger: false model: TransModel scale: 1 gpu_ids: [3] metrics: [psnr, ssim] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/x4_half.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 32 src_size: 32 use_flip: true use_rot: true val: name: DIV2K mode: PairedDataset data_type: lmdb color: RGB dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/BicLR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid.lmdb #### network structures networks: netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ buffer_size: 16 max_grad_norm: 50 losses: g1d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 lr_tv: type: TVLoss penealty: MSELoss 
weight: 0.5 g1_idt: type: L1Loss weight: 5.0 g1g2_cycle: type: L1Loss weight: 10.0 optimizers: default: type: Adam lr: !!float 2e-4 betas: [0.5, 0.999] netG1: ~ netG2: ~ netD1: ~ niter: 100000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/CinGAN/options/train/trans/2018Track4.yml ================================================ #### general settings name: Trans2018Track4 use_tb_logger: false model: TransModel scale: 1 gpu_ids: [4] metrics: [psnr, ssim] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/x4.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 32 src_size: 32 use_flip: true use_rot: true val: name: DIV2K mode: PairedDataset data_type: lmdb color: RGB dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/BicLR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid.lmdb #### network structures networks: netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ buffer_size: 16 max_grad_norm: 50 losses: g1d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 lr_tv: type: TVLoss penealty: MSELoss weight: 0.5 g1_idt: type: L1Loss weight: 5.0 g1g2_cycle: type: L1Loss weight: 10.0 optimizers: default: type: Adam lr: !!float 
2e-4 betas: [0.5, 0.999] netG1: ~ netG2: ~ netD1: ~ niter: 100000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/CinGAN/options/train/trans/2020Track1.yml ================================================ #### general settings name: Trans2020Track1 use_tb_logger: false model: TransModel scale: 1 gpu_ids: [0] metrics: [psnr, ssim] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/train_source.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 32 src_size: 32 use_flip: true use_rot: true val: name: DIV2K mode: PairedDataset data_type: lmdb color: RGB dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/BicLR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid.lmdb #### network structures networks: netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ buffer_size: 16 max_grad_norm: 50 losses: g1d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 lr_tv: type: TVLoss penealty: MSELoss weight: 0.5 g1_idt: type: L1Loss weight: 5.0 g1g2_cycle: type: L1Loss weight: 10.0 optimizers: default: type: Adam lr: !!float 2e-4 betas: [0.5, 0.999] netG1: ~ netG2: ~ netD1: ~ niter: 100000 warmup_iter: -1 # no warm up schedulers: default: 
type: MultiStepRestartLR
      milestones: [50000]
      gamma: 0.5

  manual_seed: 0
  val_freq: !!float 5e3

#### logger
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 5e3


================================================
FILE: codes/config/CinGAN/test.py
================================================
import argparse
import logging
import os.path
import sys
import time
from collections import OrderedDict, defaultdict

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

sys.path.append("../../")
import utils as util
import utils.option as option
from data import create_dataloader, create_dataset
from metrics import IQA
from models import create_model
from utils import bgr2ycbcr, imresize


def parse_args():
    """Parse command-line arguments for testing.

    NOTE(review): the parser description ("Train keypoints network") and the
    duplicated --root_path help text look like copy-paste leftovers -- confirm
    before changing the user-visible strings.
    """
    parser = argparse.ArgumentParser(description="Train keypoints network")
    # general
    parser.add_argument(
        "--opt", help="experiment configure file name", required=True, type=str
    )
    parser.add_argument(
        "--root_path",
        help="experiment configure file name",
        default="../../../",
        type=str,
    )
    # distributed training
    parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str)
    parser.add_argument(
        "--world-size",
        default=1,
        type=int,
        help="number of nodes for distributed training",
    )
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:23456",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--rank", default=0, type=int, help="node rank for distributed training"
    )

    args = parser.parse_args()

    return args


def main():
    """Entry point: parse options, set up (optional) DDP, and spawn workers."""
    args = parse_args()
    opt = option.parse(args.opt, args.root_path, is_train=False)

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])
    ngpus_per_node = torch.cuda.device_count()
    # one process per GPU: scale node count up to total process count
    args.world_size = ngpus_per_node * args.world_size
    opt["dist"] = args.world_size > 1

    util.mkdirs(
        (path for key, path in opt["path"].items() if not
key == "experiments_root") ) os.system("rm ./result") os.symlink(os.path.join(opt["path"]["results_root"], ".."), "./result") if opt["dist"]: mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args)) else: main_worker(0, 1, opt, args) def main_worker(gpu, ngpus_per_node, opt, args): if opt["dist"]: if args.dist_url == "env://" and args.rank == -1: rank = int(os.environ["RANK"]) rank = args.rank * ngpus_per_node + gpu print( f"Init process group: dist_url: {args.dist_url}, world_size: {args.world_size}, rank: {rank}" ) dist.init_process_group( backend="nccl", init_method=args.dist_url, world_size=args.world_size, rank=rank, ) torch.cuda.set_device(gpu) else: rank = 0 torch.backends.cudnn.benchmark = True util.setup_logger( "base", opt["path"]["log"], "test_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO, screen=True, tofile=True, ) measure = IQA(metrics=opt["metrics"], cuda=True) logger = logging.getLogger("base") logger.info(option.dict2str(opt)) # Create test dataset and dataloader test_datasets = [] test_loaders = [] for phase, dataset_opt in sorted(opt["datasets"].items()): test_set = create_dataset(dataset_opt) test_loader = create_dataloader(test_set, dataset_opt, opt["dist"]) if rank == 0: logger.info( "Number of test images in [{:s}]: {:d}".format( dataset_opt["name"], len(test_set) ) ) test_datasets.append(test_set) test_loaders.append(test_loader) # load pretrained model by default model = create_model(opt) for test_dataset, test_loader in zip(test_datasets, test_loaders): test_set_name = test_dataset.opt["name"] dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name) if rank == 0: logger.info("\nTesting [{:s}]...".format(test_set_name)) util.mkdir(dataset_dir) validate( model, test_dataset, test_loader, opt, measure, dataset_dir, test_set_name, logger, ) def validate( model, dataset, dist_loader, opt, measure, dataset_dir, test_set_name, logger ): test_results = {} test_results_y = {} for metric in 
opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() test_results_y[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 indices = list(range(rank, len(dataset), world_size)) for ( idx, test_data, ) in enumerate(dist_loader): idx = indices[idx] img_path = test_data["src_path"][0] img_name = img_path.split("/")[-1].split(".")[0] model.test(test_data) visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals["sr"]) # uint8 suffix = opt["suffix"] if suffix: save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png") else: save_img_path = os.path.join(dataset_dir, img_name + ".png") util.save_img(sr_img, save_img_path) message = "img:{:15s}; ".format(img_name) crop_border = opt["crop_border"] if opt["crop_border"] else opt["scale"] if crop_border == 0: cropped_sr_img = sr_img else: cropped_sr_img = sr_img[ crop_border:-crop_border, crop_border:-crop_border, : ] if "tgt" in test_data.keys(): gt_img = util.tensor2img(test_data["tgt"][0].double().cpu()) if crop_border == 0: cropped_gt_img = gt_img else: cropped_gt_img = gt_img[ crop_border:-crop_border, crop_border:-crop_border, : ] else: cropped_gt_img = None message += "Scores - " scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v message += "{}: {:.6f}; ".format(k, v) if sr_img.shape[2] == 3: # RGB image sr_img_y = bgr2ycbcr(sr_img, only_y=True) if crop_border == 0: cropped_sr_img_y = sr_img_y * 255 else: cropped_sr_img_y = ( sr_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255 ) if gt_img is not None: gt_img_y = bgr2ycbcr(gt_img, only_y=True) if crop_border == 0: cropped_gt_img_y = gt_img_y * 255 else: cropped_gt_img_y = ( gt_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255 ) else: gt_img_y = None message += "Y Scores - " scores = measure( res=cropped_sr_img_y, 
ref=cropped_gt_img_y, metrics=opt["metrics"] ) for k, v in scores.items(): test_results_y[k][idx] = v message += "{}: {:.6f}; ".format(k, v) logger.info(message) if opt["dist"]: for k, v in test_results.items(): dist.reduce(v, dst=0) dist.barrier() for k, v in test_results_y.items(): dist.reduce(v, dst=0) dist.barrier() # log avg_results = {} message = "Average Results for {}\n".format(test_set_name) if rank == 0: for k, v in test_results.items(): avg_results[k] = sum(v) / len(v) message += "{}: {:.6f}; ".format(k, avg_results[k]) logger.info(message) avg_results_y = {} message = "Average Results on Y channel for {}\n".format(test_set_name) if rank == 0: for k, v in test_results_y.items(): avg_results[k] = sum(v) / len(v) message += "{}: {:.6f}; ".format(k, avg_results[k]) logger.info(message) if __name__ == "__main__": main() ================================================ FILE: codes/config/CinGAN/train.py ================================================ import argparse import logging import math import os import random import sys import time from collections import defaultdict import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp from tensorboardX import SummaryWriter from tqdm import tqdm sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from metrics import IQA from models import create_model def parse_args(): parser = argparse.ArgumentParser(description="Train keypoints network") # general parser.add_argument( "--opt", help="experiment configure file name", required=True, type=str ) parser.add_argument( "--root_path", help="experiment configure file name", default="../../../", type=str, ) # distributed training parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str) parser.add_argument( "--world-size", default=1, type=int, help="number of nodes for distributed training", ) parser.add_argument( "--dist-url", 
default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training", ) parser.add_argument( "--rank", default=0, type=int, help="node rank for distributed training" ) args = parser.parse_args() return args def setup_dataloaer(opt, logger): if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: rank = 0 world_size = 1 for phase, dataset_opt in opt["datasets"].items(): if phase == "train": train_set = create_dataset(dataset_opt) train_loader = create_dataloader(train_set, dataset_opt, opt["dist"]) total_iters = opt["train"]["niter"] total_epochs = total_iters // (len(train_loader) - 1) + 1 if rank == 0: logger.info( "Number of train images: {:,d}, iters: {:,d}".format( len(train_set), len(train_loader) ) ) logger.info( "Total epochs needed: {:d} for iters {:,d}".format( total_epochs, opt["train"]["niter"] ) ) elif phase == "val": val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt["dist"]) if rank == 0: logger.info( "Number of val images in [{:s}]: {:d}".format( dataset_opt["name"], len(val_set) ) ) else: raise NotImplementedError("Phase [{:s}] is not recognized.".format(phase)) assert train_loader is not None assert val_loader is not None return train_set, train_loader, val_set, val_loader, total_iters, total_epochs def main(): args = parse_args() opt = option.parse(args.opt, args.root_path, is_train=True) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) if args.dist_url == "env://" and args.world_size == -1: args.world_size = int(os.environ["WORLD_SIZE"]) ngpus_per_node = torch.cuda.device_count() args.world_size = ngpus_per_node * args.world_size opt["dist"] = args.world_size > 1 if opt["train"].get("resume_state", None) is None: util.mkdir_and_rename( opt["path"]["experiments_root"] ) # rename experiment folder if exists util.mkdirs( (path for key, path in opt["path"].items() if not key == "experiments_root") ) os.system("rm 
./log") os.symlink(os.path.join(opt["path"]["experiments_root"], ".."), "./log") if opt["dist"]: mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args)) else: main_worker(0, 1, opt, args) def main_worker(gpu, ngpus_per_node, opt, args): if opt["dist"]: if args.dist_url == "env://" and args.rank == -1: rank = int(os.environ["RANK"]) rank = args.rank * ngpus_per_node + gpu print( f"Init process group: dist_url: \ {args.dist_url}, world_size: {args.world_size}, rank: {rank}" ) dist.init_process_group( backend="nccl", init_method=args.dist_url, world_size=args.world_size, rank=rank, ) torch.cuda.set_device(gpu) else: rank = 0 seed = opt["train"]["manual_seed"] if seed is None: util.set_random_seed(rank) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True # setup tensorboard and val logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger = SummaryWriter(log_dir="log/{}/tb_logger/".format(opt["name"])) util.setup_logger( "val", opt["path"]["log"], "val_" + opt["name"], level=logging.INFO, screen=True, tofile=True, ) measure = IQA(metrics=opt["metrics"], cuda=True) # config loggers. 
Before it, the log will not work util.setup_logger( "base", opt["path"]["log"], "train_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO if rank == 0 else logging.ERROR, screen=True, tofile=True, ) logger = logging.getLogger("base") if rank == 0: logger.info(option.dict2str(opt)) # create dataset ( train_set, train_loader, val_set, val_loader, total_iters, total_epochs, ) = setup_dataloaer(opt, logger) # create model model = create_model(opt) # loading resume state if exists if opt["train"].get("resume_state", None): # distributed resuming: all load into default GPU device_id = gpu resume_state = torch.load( opt["train"]["resume_state"], map_location=lambda storage, loc: storage.cuda(device_id), ) logger.info( "Resuming training from epoch: {}, iter: {}.".format( resume_state["epoch"], resume_state["iter"] ) ) start_epoch = resume_state["epoch"] current_step = resume_state["iter"] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 logger.info( "Start training from epoch: {:d}, iter: {:d}".format(start_epoch, current_step) ) data_time, iter_time = time.time(), time.time() avg_data_time = avg_iter_time = 0 count = 0 for epoch in range(start_epoch, total_epochs + 1): for _, train_data in enumerate(train_loader): current_step += 1 count += 1 if current_step > total_iters: break data_time = time.time() - data_time avg_data_time = (avg_data_time * (count - 1) + data_time) / count model.feed_data(train_data) model.optimize_parameters(current_step) model.update_learning_rate( current_step, warmup_iter=opt["train"]["warmup_iter"] ) iter_time = time.time() - iter_time avg_iter_time = (avg_iter_time * (count - 1) + iter_time) / count # log if current_step % opt["logger"]["print_freq"] == 0: logs = model.get_current_log() message = ( f" " ) message += f'[time (data): {avg_iter_time:.3f} ({avg_data_time:.3f})] ' for k, v in logs.items(): message += "{:s}: {:.4e}; ".format(k, v) # tensorboard logger if 
opt["use_tb_logger"] and "debug" not in opt["name"]: if rank == 0: tb_logger.add_scalar(k, v, current_step) logger.info(message) # validation if current_step % opt["train"]["val_freq"] == 0: avg_results = validate( model, val_set, val_loader, opt, measure, epoch, current_step ) # tensorboard logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: for k, v in avg_results.items(): tb_logger.add_scalar(k, v, current_step) # save models and training states if current_step % opt["logger"]["save_checkpoint_freq"] == 0: if rank == 0: logger.info("Saving models and training states.") model.save(current_step) model.save_training_state(epoch, current_step) data_time = time.time() iter_time = time.time() if rank == 0: logger.info("Saving the final model.") model.save("latest") logger.info("End of training.") if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger.close() def validate(model, dataset, dist_loader, opt, measure, epoch, current_step): test_results = {} for metric in opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 if rank == 0: pbar = tqdm(total=len(dataset), leave=False, dynamic_ncols=True) indices = list(range(rank, len(dataset), world_size)) for ( idx, val_data, ) in enumerate(dist_loader): idx = indices[idx] LR_img = val_data["src"] lr_img = util.tensor2img(LR_img) # save LR image for reference model.test(val_data) visuals = model.get_current_visuals() # Save images for reference img_name = val_data["src_path"][0].split("/")[-1].split(".")[0] img_dir = os.path.join(opt["path"]["val_images"], img_name) util.mkdir(img_dir) save_lr_path = os.path.join(img_dir, "{:s}_LR.png".format(img_name)) util.save_img(lr_img, save_lr_path) sr_img = util.tensor2img(visuals["sr"]) # uint8 save_img_path = os.path.join( img_dir, "{:s}_{:d}.png".format(img_name, current_step) ) util.save_img(sr_img, save_img_path) if 
"fake_lr" in visuals.keys(): fake_lr_img = util.tensor2img(visuals["fake_lr"]) save_img_path = os.path.join( img_dir, f"fake_lr_{current_step:d}.png" ) util.save_img(fake_lr_img, save_img_path) # calculate scores crop_size = opt["scale"] cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] if "tgt" in val_data.keys(): gt_img = util.tensor2img(val_data["tgt"]) cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] else: cropped_gt_img = gt_img = None scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v if rank == 0: for _ in range(world_size): pbar.update(1) if rank == 0: pbar.close() # log avg_results = {} message = " 0`, the perceptual loss will be calculated and the loss will multiplied by the weight. Default: 1.0. style_weight (float): If `style_weight > 0`, the style loss will be calculated and the loss will multiplied by the weight. Default: 0. criterion (str): Criterion used for perceptual loss. Default: 'l1'. """ def __init__( self, layer_weights, vgg_type="vgg19", use_input_norm=True, range_norm=False, perceptual_weight=1.0, style_weight=0.0, criterion="l1", ): super(PerceptualLoss, self).__init__() self.perceptual_weight = perceptual_weight self.style_weight = style_weight self.layer_weights = layer_weights self.vgg = VGGFeatureExtractor( layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, use_input_norm=use_input_norm, range_norm=range_norm, ) self.criterion_type = criterion if self.criterion_type == "l1": self.criterion = torch.nn.L1Loss() elif self.criterion_type == "l2": self.criterion = torch.nn.L2loss() elif self.criterion_type == "fro": self.criterion = None else: raise NotImplementedError(f"{criterion} criterion has not been supported.") def forward(self, x, gt): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). gt (Tensor): Ground-truth tensor with shape (n, c, h, w). 
Returns: Tensor: Forward results. """ # extract vgg features x_features = self.vgg(x) gt_features = self.vgg(gt.detach()) # calculate perceptual loss if self.perceptual_weight > 0: percep_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": percep_loss += ( torch.norm(x_features[k] - gt_features[k], p="fro") * self.layer_weights[k] ) else: percep_loss += ( self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] ) percep_loss *= self.perceptual_weight else: percep_loss = None # calculate style loss if self.style_weight > 0: style_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": style_loss += ( torch.norm( self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p="fro", ) * self.layer_weights[k] ) else: style_loss += ( self.criterion( self._gram_mat(x_features[k]), self._gram_mat(gt_features[k]), ) * self.layer_weights[k] ) style_loss *= self.style_weight else: style_loss = None return percep_loss, style_loss def _gram_mat(self, x): """Calculate Gram matrix. Args: x (torch.Tensor): Tensor with shape of (n, c, h, w). Returns: torch.Tensor: Gram matrix. 
""" n, c, h, w = x.size() features = x.view(n, c, w * h) features_t = features.transpose(1, 2) gram = features.bmm(features_t) / (c * h * w) return gram @LOSS_REGISTRY.register() class CharbonnierLoss(nn.Module): """Charbonnier Loss (L1)""" def __init__(self, eps=1e-6): super(CharbonnierLoss, self).__init__() self.eps = eps def forward(self, x, y): diff = x - y loss = torch.mean(torch.sqrt(diff * diff + self.eps)) return loss class GradientPenaltyLoss(nn.Module): def __init__(self, device=torch.device("cpu")): super(GradientPenaltyLoss, self).__init__() self.register_buffer("grad_outputs", torch.Tensor()) self.grad_outputs = self.grad_outputs.to(device) def get_grad_outputs(self, input): if self.grad_outputs.size() != input.size(): self.grad_outputs.resize_(input.size()).fill_(1.0) return self.grad_outputs def forward(self, interp, interp_crit): grad_outputs = self.get_grad_outputs(interp_crit) grad_interp = torch.autograd.grad( outputs=interp_crit, inputs=interp, grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True, )[0] grad_interp = grad_interp.view(grad_interp.size(0), -1) grad_interp_norm = grad_interp.norm(2, dim=1) loss = ((grad_interp_norm - 1) ** 2).mean() return loss ================================================ FILE: codes/config/CycleSR/archs/lr_scheduler.py ================================================ import math from collections import Counter, defaultdict import torch from torch.optim.lr_scheduler import _LRScheduler from utils.registry import LR_SCHEDULER_REGISTRY @LR_SCHEDULER_REGISTRY.register() class LinearDecayLR(_LRScheduler): def __init__( self, optimizer, decay_prop, total_steps, last_epoch=-1, ): self.decay_prop = decay_prop self.total_steps = total_steps super().__init__(optimizer, last_epoch) def get_lr(self): return [ group["initial_lr"] * (1 - (self.last_epoch + 1) * self.decay_prop / self.total_steps) for group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class 
MultiStepRestartLR(_LRScheduler): def __init__( self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, clear_state=False, last_epoch=-1, ): self.milestones = Counter(milestones) self.gamma = gamma self.clear_state = clear_state self.restarts = restarts if restarts else [0] self.restart_weights = weights if weights else [1] assert len(self.restarts) == len( self.restart_weights ), "restarts and their weights do not match." super().__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch in self.restarts: if self.clear_state: self.optimizer.state = defaultdict(dict) weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [ group["initial_lr"] * weight for group in self.optimizer.param_groups ] if self.last_epoch not in self.milestones: return [group["lr"] for group in self.optimizer.param_groups] return [ group["lr"] * self.gamma ** self.milestones[self.last_epoch] for group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class CosineAnnealingRestartLR(_LRScheduler): def __init__( self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1 ): self.T_period = T_period self.T_max = self.T_period[0] # current T period self.eta_min = eta_min self.restarts = restarts if restarts else [0] self.restart_weights = weights if weights else [1] self.last_restart = 0 assert len(self.restarts) == len( self.restart_weights ), "restarts and their weights do not match." 
super().__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch == 0: return self.base_lrs elif self.last_epoch in self.restarts: self.last_restart = self.last_epoch self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [ group["initial_lr"] * weight for group in self.optimizer.param_groups ] elif (self.last_epoch - self.last_restart - 1 - self.T_max) % ( 2 * self.T_max ) == 0: return [ group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) ] return [ (1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / ( 1 + math.cos( math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max ) ) * (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups ] ================================================ FILE: codes/config/CycleSR/archs/module_util.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init def initialize_weights(net_l, scale=1): if not isinstance(net_l, list): net_l = [net_l] for net in net_l: for m in net.modules(): if isinstance(m, nn.Conv2d): init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale # for residual block if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.Linear): init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): init.constant_(m.weight, 1) init.constant_(m.bias.data, 0.0) def make_layer(block, n_layers): layers = [] for _ in range(n_layers): layers.append(block()) return nn.Sequential(*layers) class ResidualBlock_noBN(nn.Module): """Residual block w/o BN ---Conv-ReLU-Conv-+- |________________| """ def __init__(self, nf=64): super(ResidualBlock_noBN, self).__init__() 
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) # initialization initialize_weights([self.conv1, self.conv2], 0.1) def forward(self, x): identity = x out = F.relu(self.conv1(x), inplace=True) out = self.conv2(out) return identity + out def flow_warp(x, flow, interp_mode="bilinear", padding_mode="zeros"): """Warp an image or feature map with optical flow Args: x (Tensor): size (N, C, H, W) flow (Tensor): size (N, H, W, 2), normal value interp_mode (str): 'nearest' or 'bilinear' padding_mode (str): 'zeros' or 'border' or 'reflection' Returns: Tensor: warped image or feature map """ assert x.size()[-2:] == flow.size()[1:3] B, C, H, W = x.size() # mesh grid grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 grid.requires_grad = False grid = grid.type_as(x) vgrid = grid + flow # scale grid to [-1,1] vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) return output ================================================ FILE: codes/config/CycleSR/archs/rcan.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from utils.registry import ARCH_REGISTRY def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias ) class MeanShift(nn.Conv2d): def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): super(MeanShift, self).__init__(3, 3, kernel_size=1) std = torch.Tensor(rgb_std) self.weight.data = torch.eye(3).view(3, 3, 1, 1) self.weight.data.div_(std.view(3, 1, 1, 1)) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 
self.bias.data.div_(std) self.requires_grad = False class BasicBlock(nn.Sequential): def __init__( self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True), ): m = [ nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), stride=stride, bias=bias, ) ] if bn: m.append(nn.BatchNorm2d(out_channels)) if act is not None: m.append(act) super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): def __init__( self, conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(ResBlock, self).__init__() m = [] for i in range(2): m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: m.append(nn.BatchNorm2d(n_feat)) if i == 0: m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale def forward(self, x): res = self.body(x).mul(self.res_scale) res += x return res class Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
for _ in range(int(math.log(scale, 2))): m.append(conv(n_feat, 4 * n_feat, 3, bias)) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 3: m.append(conv(n_feat, 9 * n_feat, 3, bias)) m.append(nn.PixelShuffle(3)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) else: raise NotImplementedError super(Upsampler, self).__init__(*m) def make_model(args, parent=False): return RCAN(args) ## Channel Attention (CA) Layer class CALayer(nn.Module): def __init__(self, channel, reduction=16): super(CALayer, self).__init__() # global average pooling: feature --> point self.avg_pool = nn.AdaptiveAvgPool2d(1) # feature channel downscale and upscale --> channel weight self.conv_du = nn.Sequential( nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), nn.ReLU(inplace=True), nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), nn.Sigmoid(), ) def forward(self, x): y = self.avg_pool(x) y = self.conv_du(y) return x * y ## Residual Channel Attention Block (RCAB) class RCAB(nn.Module): def __init__( self, conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(RCAB, self).__init__() modules_body = [] for i in range(2): modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: modules_body.append(nn.BatchNorm2d(n_feat)) if i == 0: modules_body.append(act) modules_body.append(CALayer(n_feat, reduction)) self.body = nn.Sequential(*modules_body) self.res_scale = res_scale def forward(self, x): res = self.body(x) # res = self.body(x).mul(self.res_scale) res += x return res ## Residual Group (RG) class ResidualGroup(nn.Module): def __init__( self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks ): super(ResidualGroup, self).__init__() modules_body = [] modules_body = [ RCAB( conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ) for _ in range(n_resblocks) ] 
modules_body.append(conv(n_feat, n_feat, kernel_size)) self.body = nn.Sequential(*modules_body) def forward(self, x): res = self.body(x) res += x return res ## Residual Channel Attention Network (RCAN) @ARCH_REGISTRY.register() class RCAN(nn.Module): def __init__(self, ng, nb, nf, reduction=16, upscale=4, conv=default_conv): super(RCAN, self).__init__() n_resgroups = ng n_resblocks = nb n_feats = nf kernel_size = 3 reduction = reduction scale = upscale act = nn.ReLU(True) # RGB mean for DIV2K rgb_mean = (0.4488, 0.4371, 0.4040) rgb_std = (1.0, 1.0, 1.0) self.sub_mean = MeanShift(1.0, rgb_mean, rgb_std, -1) # define head module modules_head = [conv(3, n_feats, kernel_size)] # define body module modules_body = [ ResidualGroup( conv, n_feats, kernel_size, reduction, act=act, res_scale=1.0, n_resblocks=nb, ) for _ in range(ng) ] modules_body.append(conv(n_feats, n_feats, kernel_size)) # define tail module modules_tail = [ Upsampler(conv, scale, n_feats, act=False), conv(n_feats, 3, kernel_size), ] self.add_mean = MeanShift(1.0, rgb_mean, rgb_std, 1) self.head = nn.Sequential(*modules_head) self.body = nn.Sequential(*modules_body) self.tail = nn.Sequential(*modules_tail) def forward(self, x): x = self.sub_mean(x) x = self.head(x) res = self.body(x) res += x x = self.tail(res) x = self.add_mean(x) return x def load_state_dict(self, state_dict, strict=False): own_state = self.state_dict() for name, param in state_dict.items(): if name in own_state: if isinstance(param, nn.Parameter): param = param.data try: own_state[name].copy_(param) except Exception: if name.find("tail") >= 0: print("Replace pre-trained upsampler to new one...") else: raise RuntimeError( "While copying the parameter named {}, " "whose dimensions in the model are {} and " "whose dimensions in the checkpoint are {}.".format( name, own_state[name].size(), param.size() ) ) elif strict: if name.find("tail") == -1: raise KeyError('unexpected key "{}" in state_dict'.format(name)) if strict: missing = 
set(own_state.keys()) - set(state_dict.keys()) if len(missing) > 0: raise KeyError('missing keys in state_dict: "{}"'.format(missing)) ================================================ FILE: codes/config/CycleSR/archs/rrdb.py ================================================ import functools from utils.registry import ARCH_REGISTRY from .module_util import * class ResidualDenseBlock_5C(nn.Module): def __init__(self, nf=64, gc=32, bias=True): super(ResidualDenseBlock_5C, self).__init__() # gc: growth channel, i.e. intermediate channels self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) # initialization initialize_weights( [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1 ) def forward(self, x): x1 = self.lrelu(self.conv1(x)) x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) return x5 * 0.2 + x class RRDB(nn.Module): """Residual in Residual Dense Block""" def __init__(self, nf, gc=32): super(RRDB, self).__init__() self.RDB1 = ResidualDenseBlock_5C(nf, gc) self.RDB2 = ResidualDenseBlock_5C(nf, gc) self.RDB3 = ResidualDenseBlock_5C(nf, gc) def forward(self, x): out = self.RDB1(x) out = self.RDB2(out) out = self.RDB3(out) return out * 0.2 + x @ARCH_REGISTRY.register() class RRDBNet(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4): super(RRDBNet, self).__init__() self.upscale = upscale RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.RRDB_trunk = make_layer(RRDB_block_f, nb) self.trunk_conv = nn.Conv2d(nf, nf, 3, 
@ARCH_REGISTRY.register()
class MSRResNet(nn.Module):
    """Modified SRResNet.

    Residual trunk (no BN) followed by pixel-shuffle upsampling, with a
    global bilinear skip connection from the input added to the output.

    Args:
        in_nc (int): Input channels. Default: 3.
        out_nc (int): Output channels. Default: 3.
        nf (int): Feature width. Default: 64.
        nb (int): Number of residual blocks. Default: 16.
        upscale (int): Upscaling factor, one of 2/3/4. Default: 4.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
        super(MSRResNet, self).__init__()
        self.upscale = upscale

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.recon_trunk = make_layer(
            functools.partial(ResidualBlock_noBN, nf=nf), nb
        )

        # Upsampling: x4 uses two x2 pixel-shuffle stages, x2/x3 a single stage.
        if self.upscale == 2:
            self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(2)
        elif self.upscale == 3:
            self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(3)
        elif self.upscale == 4:
            self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(2)

        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # initialization (scaled MSRA init, see module_util.initialize_weights)
        initialize_weights(
            [self.conv_first, self.upconv1, self.HRconv, self.conv_last], 0.1
        )
        if self.upscale == 4:
            initialize_weights(self.upconv2, 0.1)

    def forward(self, x):
        feat = self.lrelu(self.conv_first(x))
        body = self.recon_trunk(feat)

        if self.upscale == 4:
            body = self.lrelu(self.pixel_shuffle(self.upconv1(body)))
            body = self.lrelu(self.pixel_shuffle(self.upconv2(body)))
        else:  # x2 / x3: a single pixel-shuffle stage
            body = self.lrelu(self.pixel_shuffle(self.upconv1(body)))

        out = self.conv_last(self.lrelu(self.HRconv(body)))
        # Global residual: learn only the difference to a bilinear upsample.
        base = F.interpolate(
            x, scale_factor=self.upscale, mode="bilinear", align_corners=False
        )
        return out + base
def insert_bn(names):
    """Insert a bn layer name after each conv layer name.

    Args:
        names (list): The list of layer names, e.g. ['conv1_1', 'relu1_1'].

    Returns:
        list: The layer names with a matching 'bnX_Y' inserted directly
            after every 'convX_Y'.
    """
    with_bn = []
    for name in names:
        if "conv" in name:
            # e.g. 'conv3_2' -> follow it with 'bn3_2'
            with_bn.extend([name, "bn" + name.replace("conv", "")])
        else:
            with_bn.append(name)
    return with_bn
""" def __init__( self, layer_name_list, vgg_type="vgg19", use_input_norm=True, range_norm=False, requires_grad=False, remove_pooling=False, pooling_stride=2, ): super(VGGFeatureExtractor, self).__init__() self.layer_name_list = layer_name_list self.use_input_norm = use_input_norm self.range_norm = range_norm self.names = NAMES[vgg_type.replace("_bn", "")] if "bn" in vgg_type: self.names = insert_bn(self.names) # only borrow layers that will be used to avoid unused params max_idx = 0 for v in layer_name_list: idx = self.names.index(v) if idx > max_idx: max_idx = idx if os.path.exists(VGG_PRETRAIN_PATH): vgg_net = getattr(vgg, vgg_type)(pretrained=False) state_dict = torch.load( VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage ) vgg_net.load_state_dict(state_dict) else: vgg_net = getattr(vgg, vgg_type)(pretrained=True) features = vgg_net.features[: max_idx + 1] modified_net = OrderedDict() for k, v in zip(self.names, features): if "pool" in k: # if remove_pooling is true, pooling operation will be removed if remove_pooling: continue else: # in some cases, we may want to change the default stride modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) else: modified_net[k] = v self.vgg_net = nn.Sequential(modified_net) if not requires_grad: self.vgg_net.eval() for param in self.parameters(): param.requires_grad = False else: self.vgg_net.train() for param in self.parameters(): param.requires_grad = True if self.use_input_norm: # the mean is for image with range [0, 1] self.register_buffer( "mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) ) # the std is for image with range [0, 1] self.register_buffer( "std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) ) def forward(self, x): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). Returns: Tensor: Forward results. 
""" if self.range_norm: x = (x + 1) / 2 if self.use_input_norm: x = (x - self.mean) / self.std output = {} for key, layer in self.vgg_net._modules.items(): x = layer(x) if key in self.layer_name_list: output[key] = x.clone() return output ================================================ FILE: codes/config/CycleSR/count_flops.py ================================================ import argparse import sys import torch from torchsummaryX import summary sys.path.append("../../") import utils.option as option from models import create_model parser = argparse.ArgumentParser() parser.add_argument( "--opt", type=str, default="options/setting1/test/test_setting1_x4.yml", help="Path to option YMAL file of Predictor.", ) args = parser.parse_args() opt = option.parse(args.opt, root_path=".", is_train=True) opt = option.dict_to_nonedict(opt) model = create_model(opt) test_tensor = torch.randn(1, 3, 270, 180).cuda() for name, net in model.networks.items(): summary(net.cuda(), x=test_tensor) print("Above are results for net {}".format(name)) input() ================================================ FILE: codes/config/CycleSR/inference.py ================================================ import argparse import logging import math import os import os.path as osp import random import sys import cv2 from collections import defaultdict from glob import glob from tqdm import tqdm import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp from tensorboardX import SummaryWriter sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from data.data_sampler import DistIterSampler from metrics import IQA from models import create_model #### options parser = argparse.ArgumentParser() parser.add_argument( "-opt", type=str, default="options/test/2020Track2.yml", help="Path to options YMAL file.", ) parser.add_argument("-input_dir", type=str, default="../../../data_samples/LR") 
parser.add_argument("-output_dir", type=str, default="../../../data_samples/BSRGAN") args = parser.parse_args() opt = option.parse(args.opt, is_train=False) opt = option.dict_to_nonedict(opt) model = create_model(opt) if not osp.exists(args.output_dir): os.makedirs(args.output_dir) test_files = glob(osp.join(args.input_dir, "*")) for inx, path in tqdm(enumerate(test_files)): name = path.split("/")[-1].split(".")[0] img = cv2.imread(path)[:, :, [2, 1, 0]] img = img.transpose(2, 0, 1)[None] / 255 img_t = torch.as_tensor(np.ascontiguousarray(img)).float() model.test({"src": img_t}, crop_size=512) outdict = model.get_current_visuals() sr = outdict["sr"] sr_im = util.tensor2img(sr) save_path = osp.join(args.output_dir, "{}_x{}.png".format(name, opt["scale"])) cv2.imwrite(save_path, sr_im) ================================================ FILE: codes/config/CycleSR/models/__init__.py ================================================ import importlib import logging import os import os.path as osp from utils.registry import MODEL_REGISTRY logger = logging.getLogger("base") model_folder = osp.dirname(__file__) model_names = [ osp.splitext(osp.basename(v))[0] for v in os.listdir(model_folder) if v.endswith("_model.py") ] _model_modules = [ importlib.import_module(f"models.{file_name}") for file_name in model_names ] def create_model(opt, **kwarg): model = opt["model"] m = MODEL_REGISTRY.get(model)(opt, **kwarg) logger.info("Model [{:s}] is created.".format(m.__class__.__name__)) return m ================================================ FILE: codes/config/CycleSR/models/base_model.py ================================================ import logging import os from collections import OrderedDict import torch import torch.nn as nn from torch.nn.parallel import DataParallel, DistributedDataParallel from archs import build_loss, build_network, build_scheduler from utils.registry import MODEL_REGISTRY logger = logging.getLogger("base") @MODEL_REGISTRY.register() class BaseModel: def 
# NOTE(review): methods of BaseModel — the `class BaseModel:` header sits in
# the previous chunk of this dump.

def __init__(self, opt):
    """Store options, resolve rank/device, and init empty registries."""
    self.opt = opt

    if opt["dist"]:
        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()
    else:
        self.rank = 0  # non dist training

    self.device = torch.device("cuda" if opt["gpu_ids"] is not None else "cpu")
    self.is_train = opt["is_train"]

    self.log_dict = OrderedDict()
    self.data_names = []
    # name -> nn.Module / optimizer / scheduler registries, filled by subclasses
    self.networks = {}
    self.optimizers = {}
    self.schedulers = {}

def setup_train(self, train_opt):
    """Build losses, optimizers and schedulers from the train options."""
    # define losses
    loss_opt = train_opt["losses"]
    self.losses = self.build_losses(loss_opt)
    # build optmizers
    optimizer_opts = train_opt["optimizers"]
    self.optimizers = self.build_optimizers(optimizer_opts)
    # set schedulers
    scheduler_opts = train_opt["schedulers"]
    self.schedulers = self.build_schedulers(scheduler_opts)
    # set to training state
    self.set_network_state(self.networks.keys(), "train")

# The following are no-op hooks meant to be overridden by subclasses.
def feed_data(self, data):
    pass

def optimize_parameters(self):
    pass

def get_current_visuals(self):
    pass

def get_current_losses(self):
    pass

def print_network(self):
    # NOTE(review): shadowed by the print_network(self, net) definition below.
    pass

def save(self, label):
    # NOTE(review): shadowed by the save(self, iter_label) definition later on.
    pass

def load(self):
    pass

def build_network(self, net_opt):
    """Create a network, wrap it for the device, and load pretrained weights."""
    # `build_network` below is the module-level factory imported from `archs`,
    # not a recursive call to this method.
    net = build_network(net_opt)
    if isinstance(net, nn.Module):
        net = self.model_to_device(net)
    if net_opt.get("pretrain"):
        pretrain = net_opt.pop("pretrain")
        self.load_network(net, pretrain["path"], pretrain["strict_load"])
    self.print_network(net)
    return net

def build_losses(self, loss_opt):
    """Instantiate every configured loss whose weight is > 0."""
    losses = {}
    defined_loss_names = list(loss_opt.keys())
    # every configured loss must be one the subclass declared in loss_names
    assert set(defined_loss_names).issubset(set(self.loss_names))
    for name in defined_loss_names:
        loss_conf = loss_opt.get(name)
        if loss_conf["weight"] > 0:
            # pop() so the remaining conf can be passed straight to build_loss
            self.loss_weights[name] = loss_conf.pop("weight")
            losses[name] = build_loss(loss_conf).to(self.device)
    return losses

def build_optimizers(self, optim_opts):
    """Build one torch optimizer per configured network.

    A `default` entry supplies the config for networks whose entry is null.

    Returns:
        dict: name -> torch.optim.Optimizer.
    """
    optimizers = {}
    if "default" in optim_opts.keys():
        default_optim = optim_opts.pop("default")
    defined_optimizer_names = list(optim_opts.keys())
    assert set(defined_optimizer_names).issubset(self.networks.keys())

    for name in defined_optimizer_names:
        optim_opt = optim_opts[name]
        if optim_opt is None:
            optim_opt = default_optim.copy()
        # only optimize parameters that require grad
        params = []
        for v in self.networks[name].parameters():
            if v.requires_grad:
                params.append(v)
        # e.g. type: Adam -> torch.optim.Adam(params=..., **remaining opts)
        optim_type = optim_opt.pop("type")
        optimizer = getattr(torch.optim, optim_type)(params=params, **optim_opt)
        optimizers[name] = optimizer
    return optimizers

def build_schedulers(self, scheduler_opts):
    """Set up scheduler."""
    schedulers = {}
    if "default" in scheduler_opts.keys():
        default_opt = scheduler_opts.pop("default")
    # one scheduler per optimizer; null entries fall back to `default`
    for name in self.optimizers.keys():
        scheduler_opt = scheduler_opts[name]
        if scheduler_opt is None:
            scheduler_opt = default_opt.copy()
        schedulers[name] = build_scheduler(self.optimizers[name], scheduler_opt)
    return schedulers

def model_to_device(self, net):
    """Model to device. It also warps models with DistributedDataParallel
    or DataParallel.

    Args:
        net (nn.Module)
    """
    net = net.to(self.device)
    if self.opt["dist"]:
        net = DistributedDataParallel(net, device_ids=[torch.cuda.current_device()])
    else:
        net = DataParallel(net)
    return net

def print_network(self, net):
    """Log class name and parameter count of `net` (rank 0 only)."""
    # Generator
    s, n = self.get_network_description(net)
    if isinstance(net, nn.DataParallel) or isinstance(net, DistributedDataParallel):
        net_struc_str = "{} - {}".format(
            net.__class__.__name__, net.module.__class__.__name__
        )
    else:
        net_struc_str = "{}".format(net.__class__.__name__)
    if self.rank <= 0:
        logger.info(
            "Network G structure: {}, with parameters: {:,d}".format(
                net_struc_str, n
            )
        )
        logger.info(s)

def set_optimizer(self, names, operation):
    # `operation` is a method name on the optimizer, e.g. "zero_grad" or "step"
    for name in names:
        getattr(self.optimizers[name], operation)()

def set_requires_grad(self, names, requires_grad):
    # Freeze / unfreeze all parameters of the named networks.
    for name in names:
        if isinstance(self.networks[name], nn.Module):
            for v in self.networks[name].parameters():
                v.requires_grad = requires_grad

def set_network_state(self, names, state):
    # `state` is "train" or "eval"
    for name in names:
        if isinstance(self.networks[name], nn.Module):
            getattr(self.networks[name], state)()

def clip_grad_norm(self, names, norm):
    for name in names:
        nn.utils.clip_grad_norm_(self.networks[name].parameters(), max_norm=norm)
nn.utils.clip_grad_norm_(self.networks[name].parameters(), max_norm=norm) def _set_lr(self, lr_groups_l): """set learning rate for warmup, lr_groups_l: list for lr_groups. each for a optimizer""" for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): for param_group, lr in zip(optimizer.param_groups, lr_groups): param_group["lr"] = lr def _get_init_lr(self): # get the initial lr, which is set by the scheduler init_lr_groups_l = [] for optimizer in self.optimizers: init_lr_groups_l.append([v["initial_lr"] for v in optimizer.param_groups]) return init_lr_groups_l def update_learning_rate(self, cur_iter, warmup_iter=-1): for _, scheduler in self.schedulers.items(): scheduler.step() #### set up warm up learning rate if cur_iter < warmup_iter: # get initial lr for each group init_lr_g_l = self._get_init_lr() # modify warming-up learning rates warm_up_lr_l = [] for init_lr_g in init_lr_g_l: warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) # set learning rate self._set_lr(warm_up_lr_l) def get_current_learning_rate(self): # return self.schedulers[0].get_lr()[0] return list(self.optimizers.values())[0].param_groups[0]["lr"] def get_network_description(self, network): """Get the string and total parameters of the network""" if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module s = str(network) n = sum(map(lambda x: x.numel(), network.parameters())) return s, n def save_network(self, network, network_label, iter_label): save_filename = "{}_{}.pth".format(iter_label, network_label) save_path = os.path.join(self.opt["path"]["models"], save_filename) if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module state_dict = network.state_dict() for key, param in state_dict.items(): state_dict[key] = param.cpu() torch.save(state_dict, save_path) def save(self, iter_label): for name in self.optimizers.keys(): 
# NOTE(review): persistence / loss-reduction methods of BaseModel.

def save(self, iter_label):
    """Save every network that has an optimizer (i.e. every trained net)."""
    for name in self.optimizers.keys():
        self.save_network(self.networks[name], name, iter_label)

def load_network(self, network, load_path, strict=True):
    """Load a state_dict, stripping any leading 'module.' (DataParallel) prefix."""
    if load_path is not None:
        if isinstance(network, nn.DataParallel) or isinstance(
            network, DistributedDataParallel
        ):
            network = network.module
        load_net = torch.load(load_path)
        load_net_clean = OrderedDict()  # remove unnecessary 'module.'
        for k, v in load_net.items():
            if k.startswith("module."):
                load_net_clean[k[7:]] = v
            else:
                load_net_clean[k] = v
        network.load_state_dict(load_net_clean, strict=strict)

def save_training_state(self, epoch, iter_step):
    """Saves training state during training, which will be used for resuming"""
    state = {"epoch": epoch, "iter": iter_step, "schedulers": {}, "optimizers": {}}
    for k, s in self.schedulers.items():
        state["schedulers"][k] = s.state_dict()
    for k, o in self.optimizers.items():
        state["optimizers"][k] = o.state_dict()
    save_filename = "{}.state".format(iter_step)
    save_path = os.path.join(self.opt["path"]["training_state"], save_filename)
    torch.save(state, save_path)

def resume_training(self, resume_state):
    """Resume the optimizers and schedulers for training"""
    resume_optimizers = resume_state["optimizers"]
    resume_schedulers = resume_state["schedulers"]
    assert len(resume_optimizers) == len(
        self.optimizers
    ), "Wrong lengths of optimizers"
    assert len(resume_schedulers) == len(
        self.schedulers
    ), "Wrong lengths of schedulers"
    for name, o in resume_optimizers.items():
        self.optimizers[name].load_state_dict(o)
    for name, s in resume_schedulers.items():
        self.schedulers[name].load_state_dict(s)

def reduce_loss_dict(self, loss_dict):
    """reduce loss dict.

    In distributed training, it averages the losses among different GPUs.

    Args:
        loss_dict (OrderedDict): Loss dict.
    """
    with torch.no_grad():
        if self.opt["dist"]:
            keys = []
            losses = []
            for name, value in loss_dict.items():
                keys.append(name)
                losses.append(value)
            losses = torch.stack(losses, 0)
            # sum onto rank 0, then average there; other ranks keep raw sums
            torch.distributed.reduce(losses, dst=0)
            if self.rank == 0:
                losses /= self.world_size
            loss_dict = {key: loss for key, loss in zip(keys, losses)}

        log_dict = OrderedDict()
        for name, value in loss_dict.items():
            log_dict[name] = value.mean().item()

        return log_dict

def get_current_log(self):
    return self.log_dict
# NOTE(review): methods of CycleGANModel — the class header and __init__
# begin in the previous chunk of this dump.

def feed_data(self, data):
    # unpaired batches: "src" and "tgt" domains
    self.src = data["src"].to(self.device)
    self.tgt = data["tgt"].to(self.device)

def forward(self):
    # G1: src -> tgt, G2: tgt -> src, plus both cycle reconstructions
    self.fake_tgt = self.netG1(self.src)
    self.rec_src = self.netG2(self.fake_tgt)
    self.fake_src = self.netG2(self.tgt)
    self.rec_tgt = self.netG1(self.fake_src)

def optimize_parameters(self, step):
    """One CycleGAN step: update G1/G2 (adv + idt + cycle), then D1/D2."""
    loss_dict = OrderedDict()
    self.forward()

    loss_G = 0
    # set D fixed
    self.set_requires_grad(["netD1", "netD2"], False)

    g1_adv_loss = self.calculate_gan_loss_G(
        self.netD1, self.losses["g1d1_adv"], self.tgt, self.fake_tgt
    )
    loss_dict["g1_adv"] = g1_adv_loss.item()
    loss_G += self.loss_weights["g1d1_adv"] * g1_adv_loss

    g2_adv_loss = self.calculate_gan_loss_G(
        self.netD2, self.losses["g2d2_adv"], self.src, self.fake_src
    )
    loss_dict["g2_adv"] = g2_adv_loss.item()
    loss_G += self.loss_weights["g2d2_adv"] * g2_adv_loss

    # identity losses are optional (configured weight > 0)
    if self.losses.get("g1_idt"):
        self.tgt_idt = self.netG1(self.tgt)
        g1_idt = self.losses["g1_idt"](self.tgt, self.tgt_idt)
        loss_dict["g1_idt"] = g1_idt.item()
        loss_G += self.loss_weights["g1_idt"] * g1_idt

    if self.losses.get("g2_idt"):
        self.src_idt = self.netG2(self.src)
        g2_idt = self.losses["g2_idt"](self.src, self.src_idt)
        loss_dict["g2_idt"] = g2_idt.item()
        loss_G += self.loss_weights["g2_idt"] * g2_idt

    g1g2_cycle = self.losses["g1g2_cycle"](self.rec_src, self.src)
    loss_dict["g1g2_cycle"] = g1g2_cycle.item()
    loss_G += self.loss_weights["g1g2_cycle"] * g1g2_cycle

    g2g1_cycle = self.losses["g2g1_cycle"](self.rec_tgt, self.tgt)
    loss_dict["g2g1_cycle"] = g2g1_cycle.item()
    loss_G += self.loss_weights["g2g1_cycle"] * g2g1_cycle

    self.set_optimizer(names=["netG1", "netG2"], operation="zero_grad")
    loss_G.backward()
    self.clip_grad_norm(names=["netG1", "netG2"], norm=self.max_grad_norm)
    self.set_optimizer(names=["netG1", "netG2"], operation="step")

    ## update D1, D2
    self.set_requires_grad(["netD1", "netD2"], True)

    loss_D = 0
    # D sees fakes drawn from the shuffle buffer (image-pool trick)
    loss_d1 = self.calculate_gan_loss_D(
        self.netD1, self.losses["g1d1_adv"], self.tgt,
        self.fake_tgt_buffer.choose(self.fake_tgt.detach())
    )
    loss_dict["d1_adv"] = loss_d1.item()
    loss_D += loss_d1

    # NOTE(review): no .detach() here, unlike the d1 branch — harmless since
    # choose() stores `.data` and calculate_gan_loss_D detaches again, but
    # inconsistent; verify intent.
    loss_d2 = self.calculate_gan_loss_D(
        self.netD2, self.losses["g2d2_adv"], self.src,
        self.fake_src_buffer.choose(self.fake_src)
    )
    loss_dict["d2_adv"] = loss_d2.item()
    loss_D += loss_d2

    self.set_optimizer(names=["netD1", "netD2"], operation="zero_grad")
    loss_D.backward()
    self.clip_grad_norm(names=["netD1", "netD2"], norm=self.max_grad_norm)
    self.set_optimizer(names=["netD1", "netD2"], operation="step")

    self.log_dict = loss_dict

def calculate_gan_loss_D(self, netD, criterion, real, fake):
    # discriminator loss: average of real/fake terms; fake is detached
    d_pred_fake = netD(fake.detach())
    d_pred_real = netD(real)
    loss_real = criterion(d_pred_real, True, is_disc=True)
    loss_fake = criterion(d_pred_fake, False, is_disc=True)
    return (loss_real + loss_fake) / 2

def calculate_gan_loss_G(self, netD, criterion, real, fake):
    # generator loss: fool D into labelling the fake as real
    d_pred_fake = netD(fake)
    loss_real = criterion(d_pred_fake, True, is_disc=False)
    return loss_real

def test(self, data):
    # inference: run G1 in eval mode under no_grad, then restore train mode
    self.src = data["src"].to(self.device)
    self.netG1.eval()
    with torch.no_grad():
        self.fake_tgt = self.netG1(self.src)
    self.netG1.train()

def get_current_visuals(self, need_GT=True):
    out_dict = OrderedDict()
    out_dict["lr"] = self.src.detach()[0].float().cpu()
    out_dict["sr"] = self.fake_tgt.detach()[0].float().cpu()
    return out_dict
:param images: the latest generated images from the generator :type images: list :param prob: probability (0~1) of return previous images from buffer :type prob: float :return: Return images from the buffer :rtype: list """ return_images = [] for image in images: image = torch.unsqueeze(image.data, 0) if self.num_imgs < self.buffer_size: self.images.append(image) return_images.append(image) self.num_imgs += 1 else: p = random.uniform(0, 1) if p < prob: idx = random.randint(0, self.buffer_size - 1) stored_image = self.images[idx].clone() self.images[idx] = image return_images.append(stored_image) else: return_images.append(image) return_images = torch.cat(return_images, 0) return return_images ================================================ FILE: codes/config/CycleSR/models/cyclesr_model.py ================================================ import logging from collections import OrderedDict import torch import torch.nn as nn from utils.registry import MODEL_REGISTRY from .base_model import BaseModel from .cyclegan_model import ShuffleBuffer logger = logging.getLogger("base") @MODEL_REGISTRY.register() class CycleSRModel(BaseModel): def __init__(self, opt): super().__init__(opt) if opt["dist"]: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training self.data_names = ["syn_lr", "syn_hr", "real_lr"] self.network_names = ["netSR", "netG1", "netG2", "netD1", "netD2", "netD3"] self.networks = {} self.loss_names = [ "sr_adv", "sr_pix", "sr_pix_trans", "sr_percep", "g1_d1_adv", "g2_d2_adv", "g1_idt", "g2_idt", "g1g2_cycle", "g2g1_cycle", ] self.loss_weights = {} self.losses = {} self.optimizers = {} # define networks and load pretrained models nets_opt = opt["networks"] defined_network_names = list(nets_opt.keys()) assert set(defined_network_names).issubset(set(self.network_names)) for name in defined_network_names: setattr(self, name, self.build_network(nets_opt[name])) self.networks[name] = getattr(self, name) if self.is_train: train_opt = 
# NOTE(review): methods of CycleSRModel — the class header and __init__
# begin in the previous chunk of this dump.

def feed_data(self, data):
    # paired synthetic data ("ref_*") plus unpaired real LR ("src")
    self.syn_lr = data["ref_src"].to(self.device)
    self.syn_hr = data["ref_tgt"].to(self.device)
    self.real_lr = data["src"].to(self.device)

def forward_trans(self):
    # G1: syn LR -> real-like LR; G2: real LR -> syn-like LR; plus cycles.
    self.fake_real_lr = self.netG1(self.syn_lr)
    self.fake_syn_hr = self.netSR(self.fake_real_lr)
    self.rec_syn_lr = self.netG2(self.fake_real_lr)

    self.fake_syn_lr = self.netG2(self.real_lr)
    self.rec_real_lr = self.netG1(self.fake_syn_lr)

def forward_sr(self):
    # detach(): the SR step must not backprop into translator G1
    self.fake_syn_hr = self.netSR(self.fake_real_lr.detach())
    if self.losses.get("sr_adv"):
        self.fake_real_hr = self.netSR(self.real_lr)

def optimize_trans_models(self, step, loss_dict):
    """Update the translators G1/G2 (adv + cycle + idt + SR-pix), then D1/D2."""
    # set D fixed
    self.set_requires_grad(["netD1", "netD2", "netSR"], False)
    self.forward_trans()

    loss_trans = 0

    g1_adv_loss = self.calculate_gan_loss_G(
        self.netD1, self.losses["g1_d1_adv"], self.real_lr, self.fake_real_lr
    )
    loss_dict["g1_adv"] = g1_adv_loss.item()
    loss_trans += self.loss_weights["g1_d1_adv"] * g1_adv_loss

    g2_adv_loss = self.calculate_gan_loss_G(
        self.netD2, self.losses["g2_d2_adv"], self.syn_lr, self.fake_syn_lr
    )
    loss_dict["g2_adv"] = g2_adv_loss.item()
    loss_trans += self.loss_weights["g2_d2_adv"] * g2_adv_loss

    g1g2_cycle = self.losses["g1g2_cycle"](self.rec_syn_lr, self.syn_lr)
    loss_dict["g1g2_cycle"] = g1g2_cycle.item()
    loss_trans += self.loss_weights["g1g2_cycle"] * g1g2_cycle

    # identity losses are optional (configured weight > 0)
    if self.losses.get("g1_idt"):
        self.real_lr_idt = self.netG1(self.real_lr)
        g1_idt = self.losses["g1_idt"](self.real_lr, self.real_lr_idt)
        loss_dict["g1_idt"] = g1_idt.item()
        loss_trans += self.loss_weights["g1_idt"] * g1_idt

    if self.losses.get("g2_idt"):
        self.syn_lr_idt = self.netG2(self.syn_lr)
        g2_idt = self.losses["g2_idt"](self.syn_lr, self.syn_lr_idt)
        loss_dict["g2_idt"] = g2_idt.item()
        loss_trans += self.loss_weights["g2_idt"] * g2_idt

    g2g1_cycle = self.losses["g2g1_cycle"](self.rec_real_lr, self.real_lr)
    loss_dict["g2g1_cycle"] = g2g1_cycle.item()
    loss_trans += self.loss_weights["g2g1_cycle"] * g2g1_cycle

    # supervised SR loss also shapes the translator through fake_real_lr
    loss_sr_pix = self.losses["sr_pix_trans"](self.fake_syn_hr, self.syn_hr)
    loss_dict["sr_pix_trans"] = loss_sr_pix.item()
    loss_trans += self.loss_weights["sr_pix_trans"] * loss_sr_pix

    self.set_optimizer(names=["netG1", "netG2"], operation="zero_grad")
    loss_trans.backward()
    self.clip_grad_norm(["netG1", "netG2"], self.max_grad_norm)
    self.set_optimizer(names=["netG1", "netG2"], operation="step")

    ## update D1, D2
    self.set_requires_grad(["netD1", "netD2"], True)

    loss_d1d2 = 0
    loss_d1 = self.calculate_gan_loss_D(
        self.netD1, self.losses["g1_d1_adv"], self.real_lr, self.fake_real_lr
    )
    loss_dict["d1_adv"] = loss_d1.item()
    loss_d1d2 += loss_d1

    loss_d2 = self.calculate_gan_loss_D(
        self.netD2, self.losses["g2_d2_adv"], self.syn_lr, self.fake_syn_lr
    )
    loss_dict["d2_adv"] = loss_d2.item()
    loss_d1d2 += loss_d2

    self.set_optimizer(names=["netD1", "netD2"], operation="zero_grad")
    loss_d1d2.backward()
    self.clip_grad_norm(["netD1", "netD2"], self.max_grad_norm)
    self.set_optimizer(names=["netD1", "netD2"], operation="step")

    return loss_dict

def optimize_sr_models(self, step, loss_dict):
    """Update netSR (pix + optional adv/percep), then its discriminator D3."""
    self.set_requires_grad(["netSR"], True)
    self.forward_sr()

    l_sr = 0

    sr_pix = self.losses["sr_pix"](self.syn_hr, self.fake_syn_hr)
    loss_dict["sr_pix"] = sr_pix.item()
    l_sr += self.loss_weights["sr_pix"] * sr_pix

    if self.losses.get("sr_adv"):
        self.set_requires_grad(["netD3"], False)
        sr_adv_g = self.calculate_gan_loss_G(
            self.netD3, self.losses["sr_adv"], self.syn_hr, self.fake_syn_hr
        )
        loss_dict["sr_adv_g"] = sr_adv_g.item()
        l_sr += self.loss_weights["sr_adv"] * sr_adv_g

    if self.losses.get("sr_percep"):
        sr_percep, sr_style = self.losses["sr_percep"](
            self.syn_hr, self.fake_syn_hr
        )
        loss_dict["sr_percep"] = sr_percep.item()
        if sr_style is not None:
            loss_dict["sr_style"] = sr_style.item()
            # NOTE(review): the style term reuses the "sr_percep" weight;
            # there is no separate "sr_style" entry in loss_names.
            l_sr += self.loss_weights["sr_percep"] * sr_style
        l_sr += self.loss_weights["sr_percep"] * sr_percep

    self.set_optimizer(names=["netSR"], operation="zero_grad")
    l_sr.backward()
    self.clip_grad_norm(["netSR"], self.max_grad_norm)
    self.set_optimizer(names=["netSR"], operation="step")

    if self.losses.get("sr_adv"):
        self.set_requires_grad(["netD3"], True)
        sr_adv_d = self.calculate_gan_loss_D(
            self.netD3, self.losses["sr_adv"], self.syn_hr, self.fake_syn_hr
        )
        loss_dict["sr_adv_d"] = sr_adv_d.item()
        loss_D = self.loss_weights["sr_adv"] * sr_adv_d
        self.optimizers["netD3"].zero_grad()
        loss_D.backward()
        self.clip_grad_norm(["netD3"], self.max_grad_norm)
        self.optimizers["netD3"].step()

    return loss_dict

def optimize_parameters(self, step):
    # two-stage update: translators (+D1/D2) first, then SR (+D3)
    loss_dict = OrderedDict()
    loss_dict = self.optimize_trans_models(step, loss_dict)
    loss_dict = self.optimize_sr_models(step, loss_dict)
    for k, v in loss_dict.items():
        self.log_dict[k] = v

def calculate_gan_loss_D(self, netD, criterion, real, fake):
    # discriminator loss: average of real/fake terms; fake is detached
    d_pred_fake = netD(fake.detach())
    d_pred_real = netD(real)
    loss_real = criterion(d_pred_real, True, is_disc=True)
    loss_fake = criterion(d_pred_fake, False, is_disc=True)
    return (loss_real + loss_fake) / 2

def calculate_gan_loss_G(self, netD, criterion, real, fake):
    # generator loss: fool D into labelling the fake as real
    d_pred_fake = netD(fake)
    loss_real = criterion(d_pred_fake, True, is_disc=False)
    return loss_real

def test(self, data):
    # inference: SR of a real LR input, eval mode under no_grad
    self.real_lr = data["src"].to(self.device)
    self.netSR.eval()
    with torch.no_grad():
        self.fake_real_hr = self.netSR(self.real_lr)
    self.netSR.train()

def get_current_visuals(self, need_GT=True):
    out_dict = OrderedDict()
    out_dict["lr"] = self.real_lr.detach()[0].float().cpu()
    out_dict["sr"] = self.fake_real_hr.detach()[0].float().cpu()
    return out_dict
niqe, piqe, brisque] datasets: test1: name: 2017Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test2: # name: 2018Track2 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_mild.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test3: # name: 2018Track3 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_difficult.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test4: # name: 2018Track4 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_wild.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test5: # name: 2020Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1_valid_input.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2020/track1_valid_gt.lmdb #### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/CycleSR2017Track1/models/200000_netSR.pth strict_load: true ================================================ FILE: codes/config/CycleSR/options/test/sr/2018Track2.yml ================================================ #### general settings name: 2018Track2 use_tb_logger: false model: CycleSRModel scale: 4 gpu_ids: [2] metrics: [best_psnr, best_ssim, lpips, niqe, piqe, brisque] datasets: # test1: # name: 2017Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2017/validx4.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb test2: name: 2018Track2 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test3: # name: 2018Track3 # mode: PairedDataset # data_type: lmdb # dataroot_src: 
/home/lzx/SRDatasets/NTIRE2018/valid_difficult.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test4: # name: 2018Track4 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_wild.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test5: # name: 2020Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1_valid_input.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2020/track1_valid_gt.lmdb #### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/CycleSR2018Track2/models/latest_netSR.pth strict_load: true ================================================ FILE: codes/config/CycleSR/options/test/sr/2018Track4.yml ================================================ #### general settings name: 2018Track4 use_tb_logger: false model: CycleSRModel scale: 4 gpu_ids: [3] metrics: [best_psnr, best_ssim, lpips, niqe, piqe, brisque] datasets: # test1: # name: 2017Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2017/validx4.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test2: # name: 2018Track2 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid_mild.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test3: # name: 2018Track3 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_difficult.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb test4: name: 2018Track4 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test5: # name: 2020Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1_valid_input.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2020/track1_valid_gt.lmdb 
#### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/CycleSR2018Track4/models/latest_netSR.pth strict_load: true ================================================ FILE: codes/config/CycleSR/options/test/sr/2020Track1.yml ================================================ #### general settings name: 2020Track1 use_tb_logger: false model: CycleSRModel scale: 4 gpu_ids: [0] metrics: [psnr, ssim, lpips, niqe, piqe, brisque] datasets: # test1: # name: 2017Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2017/validx4.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test2: # name: 2018Track2 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_mild.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test3: # name: 2018Track3 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_difficult.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test4: # name: 2018Track4 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_wild.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb test5: name: 2020Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/CycleSR2020Track1/models/200000_netSR.pth strict_load: true ================================================ FILE: codes/config/CycleSR/options/test/sr/2020Track1_percep.yml ================================================ #### general settings name: 2020Track1_percep use_tb_logger: false model: CycleSRModel scale: 4 gpu_ids: [2] metrics: [psnr, ssim, lpips, niqe, piqe, brisque] datasets: # test1: # name: 
2017Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2017/validx4.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test2: # name: 2018Track2 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_mild.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test3: # name: 2018Track3 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_difficult.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test4: # name: 2018Track4 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_wild.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb test5: name: 2020Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/CycleSR2020Track1_percep/models/200000_netSR.pth strict_load: true ================================================ FILE: codes/config/CycleSR/options/train/sr/psnr/2017Track2.yml ================================================ #### general settings name: CycleSR2017Track1 use_tb_logger: false model: CycleSRModel scale: 4 gpu_ids: [3] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: PairedRefDataset data_type: lmdb color: RGB ratios: [1, 1] dataroot_ref_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_ref_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/train_LR/x4_half.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2017Track1_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: 
/home/lzx/SRDatasets/NTIRE2017/valid_LR/x4_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true # netD3: # which_network: PatchGANDiscriminator # setting: # in_c: 3 # nf: 64 # nb: 3 # stride: 2 # pretrain: # path: ~ # strict_load: true #### network structures netG1: which_network: Translator setting: nf: 64 nb: 8 scale: 1 pretrain: path: log/Trans2017Track1/models/190000_netG1.pth strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/Trans2017Track1/models/190000_netD1.pth strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 8 scale: 1 pretrain: path: log/Trans2017Track1/models/190000_netG2.pth strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/Trans2017Track1/models/190000_netD2.pth strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ max_grad_norm: 50.0 buffer_size: 16 losses: # sr_adv: # type: GANLoss # gan_type: lsgan # real_label_val: 1.0 # fake_label_val: 0.0 # weight: !!float 0.0 sr_pix_trans: type: MSELoss weight: 1000.0 sr_pix: type: MSELoss weight: 1.0 # sr_percep: # type: PerceptualLoss # layer_weights: # 'conv5_4': 1 # before relu # vgg_type: vgg19 # use_input_norm: true # range_norm: false # perceptual_weight: 1.0 # style_weight: 0 # criterion: l1 # weight: !!float 0.0 g1_d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g2_d2_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g1_idt: type: L1Loss weight: 5 g2_idt: type: L1Loss weight: 5 g1g2_cycle: type: L1Loss weight: 10.0 g2g1_cycle: type: L1Loss weight: 10.0 optimizers: default: type: Adam lr: !!float 1e-4 betas: [0.5, 0.999] 
netSR: ~ netG1: ~ netG2: ~ netD1: ~ netD2: ~ # netD3: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/CycleSR/options/train/sr/psnr/2018Track2.yml ================================================ #### general settings name: CycleSR2018Track2 use_tb_logger: false model: CycleSRModel scale: 4 gpu_ids: [0] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: PairedRefDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_ref_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_ref_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/x4_half.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2018Track2_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures # netSR: # which_network: RRDBNet # setting: # in_nc: 3 # out_nc: 3 # nf: 64 # nb: 23 # upscale: 4 # pretrain: # path: ../../../checkpoints/ESRGAN/RRDB_PSNR_x4.pth # strict_load: true networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true # netD3: # which_network: PatchGANDiscriminator # setting: # in_c: 3 # nf: 64 # nb: 3 # stride: 2 # pretrain: # path: ~ # strict_load: true #### network structures netG1: which_network: Translator setting: nf: 64 nb: 8 scale: 1 pretrain: path: log/Trans2018Track2/models/latest_netG1.pth strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64
nb: 3 stride: 2 pretrain: path: log/Trans2018Track2/models/latest_netD1.pth strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 8 scale: 1 pretrain: path: log/Trans2018Track2/models/latest_netG2.pth strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/Trans2018Track2/models/latest_netD2.pth strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ max_grad_norm: 50.0 buffer_size: 16 losses: # sr_adv: # type: GANLoss # gan_type: lsgan # real_label_val: 1.0 # fake_label_val: 0.0 # weight: !!float 0.0 sr_pix_trans: type: MSELoss weight: 1000.0 sr_pix: type: MSELoss weight: 1.0 # sr_percep: # type: PerceptualLoss # layer_weights: # 'conv5_4': 1 # before relu # vgg_type: vgg19 # use_input_norm: true # range_norm: false # perceptual_weight: 1.0 # style_weight: 0 # criterion: l1 # weight: !!float 0.0 g1_d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g2_d2_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g1_idt: type: L1Loss weight: 5 g2_idt: type: L1Loss weight: 5 g1g2_cycle: type: L1Loss weight: 10.0 g2g1_cycle: type: L1Loss weight: 10.0 optimizers: default: type: Adam lr: !!float 1e-4 betas: [0.5, 0.999] netSR: ~ netG1: ~ netG2: ~ netD1: ~ netD2: ~ # netD3: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/CycleSR/options/train/sr/psnr/2018Track4.yml ================================================ #### general settings name: CycleSR2018Track4 use_tb_logger: false model: CycleSRModel scale: 4 gpu_ids: [0] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: 
PairedRefDataset data_type: lmdb color: RGB ratios: [200, 50] dataroot_ref_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_ref_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/x4.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2018Track4_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures # netSR: # which_network: RRDBNet # setting: # in_nc: 3 # out_nc: 3 # nf: 64 # nb: 23 # upscale: 4 # pretrain: # path: ../../../checkpoints/ESRGAN/RRDB_PSNR_x4.pth # strict_load: true networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true # netD3: # which_network: PatchGANDiscriminator # setting: # in_c: 3 # nf: 64 # nb: 3 # stride: 2 # pretrain: # path: ~ # strict_load: true #### network structures netG1: which_network: Translator setting: nf: 64 nb: 8 scale: 1 pretrain: path: log/Trans2018Track4/models/latest_netG1.pth strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/Trans2018Track4/models/latest_netD1.pth strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 8 scale: 1 pretrain: path: log/Trans2018Track4/models/latest_netG2.pth strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/Trans2018Track4/models/latest_netD2.pth strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ max_grad_norm: 50.0 buffer_size: 16 losses: # sr_adv: # type: GANLoss # gan_type: lsgan # real_label_val: 1.0 # fake_label_val: 0.0 # weight: !!float 0.0 sr_pix_trans: type: MSELoss weight: 1000.0 
sr_pix: type: MSELoss weight: 1.0 # sr_percep: # type: PerceptualLoss # layer_weights: # 'conv5_4': 1 # before relu # vgg_type: vgg19 # use_input_norm: true # range_norm: false # perceptual_weight: 1.0 # style_weight: 0 # criterion: l1 # weight: !!float 0.0 g1_d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g2_d2_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g1_idt: type: L1Loss weight: 5 g2_idt: type: L1Loss weight: 5 g1g2_cycle: type: L1Loss weight: 10.0 g2g1_cycle: type: L1Loss weight: 10.0 optimizers: default: type: Adam lr: !!float 1e-4 betas: [0.5, 0.999] netSR: ~ netG1: ~ netG2: ~ netD1: ~ netD2: ~ # netD3: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/CycleSR/options/train/sr/psnr/2020Track1.yml ================================================ #### general settings name: CycleSR2020Track1 use_tb_logger: false model: CycleSRModel scale: 4 gpu_ids: [4] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: PairedRefDataset data_type: lmdb color: RGB ratios: [200, 50] dataroot_ref_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_ref_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/train_source.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2020Track1_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures # netSR: # which_network: RRDBNet # setting: # in_nc: 3 # out_nc: 3 # nf: 64 
# nb: 23 # upscale: 4 # pretrain: # path: ../../../checkpoints/ESRGAN/RRDB_PSNR_x4.pth # strict_load: true networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true # netD3: # which_network: PatchGANDiscriminator # setting: # in_c: 3 # nf: 64 # nb: 3 # stride: 2 # pretrain: # path: ~ # strict_load: true #### network structures netG1: which_network: Translator setting: nf: 64 nb: 8 scale: 1 pretrain: path: log/Trans2020Track1/models/latest_netG1.pth strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/Trans2020Track1/models/latest_netD1.pth strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 8 scale: 1 pretrain: path: log/Trans2020Track1/models/latest_netG2.pth strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/Trans2020Track1/models/latest_netD2.pth strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ max_grad_norm: 50.0 buffer_size: 16 losses: # sr_adv: # type: GANLoss # gan_type: lsgan # real_label_val: 1.0 # fake_label_val: 0.0 # weight: !!float 0.0 sr_pix_trans: type: MSELoss weight: 1000.0 sr_pix: type: MSELoss weight: 1.0 # sr_percep: # type: PerceptualLoss # layer_weights: # 'conv5_4': 1 # before relu # vgg_type: vgg19 # use_input_norm: true # range_norm: false # perceptual_weight: 1.0 # style_weight: 0 # criterion: l1 # weight: !!float 0.0 g1_d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g2_d2_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g1_idt: type: L1Loss weight: 5 g2_idt: type: L1Loss weight: 5 g1g2_cycle: type: L1Loss weight: 10.0 g2g1_cycle: type: L1Loss weight: 10.0 optimizers: default: type: Adam lr: !!float 1e-4 betas: [0.5, 0.999] netSR: ~ 
netG1: ~ netG2: ~ netD1: ~ netD2: ~ # netD3: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/CycleSR/options/train/trans/2017Track2.yml ================================================ #### general settings name: Trans2017Track1 use_tb_logger: false model: CycleGANModel scale: 1 gpu_ids: [3] metrics: [psnr, ssim] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [1, 1] dataroot_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4_half.lmdb dataroot_tgt: /home/lzx/SRDatasets/NTIRE2017/train_LR/x4_half.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 32 src_size: 32 use_flip: true use_rot: true val: name: DIV2K mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/DIV2K_valid/BicLR/x4.lmdb dataroot_tgt: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb #### network structures networks: netG1: which_network: Translator setting: nf: 64 nb: 8 scale: 1 zero_tail: true pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 pretrain: path: ~ strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 8 scale: 1 zero_tail: true pretrain: path: ~ strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ buffer_size: 16 max_grad_norm: 50 losses: g1d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g2d2_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g1_idt: type: L1Loss weight: 5.0 g2_idt: type: L1Loss weight: 5.0 
g1g2_cycle: type: L1Loss weight: 10.0 g2g1_cycle: type: L1Loss weight: 10.0 optimizers: default: type: Adam lr: !!float 2e-4 betas: [0.5, 0.999] netG1: ~ netG2: ~ netD1: ~ netD2: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/CycleSR/options/train/trans/2018Track2.yml ================================================ #### general settings name: Trans2018Track2 use_tb_logger: false model: CycleGANModel scale: 1 gpu_ids: [0] metrics: [psnr, ssim] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4_half.lmdb dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/track2/x4_half.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 32 src_size: 32 use_flip: true use_rot: true val: name: DIV2K mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/DIV2K_valid/BicLR/x4.lmdb dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/track2/valid_mild.lmdb #### network structures networks: netG1: which_network: Translator setting: nf: 64 nb: 8 scale: 1 zero_tail: true pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 pretrain: path: ~ strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 8 scale: 1 zero_tail: true pretrain: path: ~ strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ buffer_size: 16 max_grad_norm: 50 losses: g1d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g2d2_adv: type: GANLoss 
gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g1_idt: type: L1Loss weight: 5.0 g2_idt: type: L1Loss weight: 5.0 g1g2_cycle: type: L1Loss weight: 10.0 g2g1_cycle: type: L1Loss weight: 10.0 optimizers: default: type: Adam lr: !!float 2e-4 betas: [0.5, 0.999] netG1: ~ netG2: ~ netD1: ~ netD2: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/CycleSR/options/train/trans/2018Track4.yml ================================================ #### general settings name: Trans2018Track4 use_tb_logger: false model: CycleGANModel scale: 1 gpu_ids: [1] metrics: [psnr, ssim] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4.lmdb dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/track4/x4.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 32 src_size: 32 use_flip: true use_rot: true val: name: DIV2K mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/DIV2K_valid/BicLR/x4.lmdb dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/track4/valid_wild.lmdb #### network structures networks: netG1: which_network: Translator setting: nf: 64 nb: 8 scale: 1 zero_tail: true pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 pretrain: path: ~ strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 8 scale: 1 zero_tail: true pretrain: path: ~ strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ buffer_size: 16 max_grad_norm: 50 
losses: g1d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g2d2_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g1_idt: type: L1Loss weight: 5.0 g2_idt: type: L1Loss weight: 5.0 g1g2_cycle: type: L1Loss weight: 10.0 g2g1_cycle: type: L1Loss weight: 10.0 optimizers: default: type: Adam lr: !!float 2e-4 betas: [0.5, 0.999] netG1: ~ netG2: ~ netD1: ~ netD2: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/CycleSR/options/train/trans/2020Track1.yml ================================================ #### general settings name: Trans2020Track1 use_tb_logger: false model: CycleGANModel scale: 1 gpu_ids: [1] metrics: [psnr, ssim] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4.lmdb dataroot_tgt: /home/lzx/SRDatasets/NTIRE2020/track1/train_source.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 32 src_size: 32 use_flip: true use_rot: true val: name: DIV2K mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/DIV2K_valid/BicLR/x4.lmdb dataroot_tgt: /home/lzx/SRDatasets/NTIRE2020/track1/valid.lmdb #### network structures networks: netG1: which_network: Translator setting: nf: 64 nb: 8 scale: 1 zero_tail: true pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 pretrain: path: ~ strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 8 scale: 1 zero_tail: true pretrain: path: ~ strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 pretrain: 
path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ buffer_size: 16 max_grad_norm: 50 losses: g1d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g2d2_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g1_idt: type: L1Loss weight: 5.0 g2_idt: type: L1Loss weight: 5.0 g1g2_cycle: type: L1Loss weight: 10.0 g2g1_cycle: type: L1Loss weight: 10.0 optimizers: default: type: Adam lr: !!float 2e-4 betas: [0.5, 0.999] netG1: ~ netG2: ~ netD1: ~ netD2: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/CycleSR/test.py ================================================ import argparse import logging import os.path import sys import time from collections import OrderedDict, defaultdict import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from metrics import IQA from models import create_model from utils import bgr2ycbcr, imresize def parse_args(): parser = argparse.ArgumentParser(description="Train keypoints network") # general parser.add_argument( "--opt", help="experiment configure file name", required=True, type=str ) parser.add_argument( "--root_path", help="experiment configure file name", default="../../../", type=str, ) # distributed training parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str) parser.add_argument( "--world-size", default=1, type=int, help="number of nodes for distributed training", ) parser.add_argument( "--dist-url", default="tcp://127.0.0.1:23456", 
def main():
    """Entry point: parse CLI args, load the YAML option file, set up the
    results symlink, and launch one test worker per GPU (distributed) or a
    single in-process worker."""
    args = parse_args()
    opt = option.parse(args.opt, args.root_path, is_train=False)
    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    ngpus_per_node = torch.cuda.device_count()
    args.world_size = ngpus_per_node * args.world_size
    opt["dist"] = args.world_size > 1

    util.mkdirs(
        (path for key, path in opt["path"].items() if not key == "experiments_root")
    )

    # Refresh the ./result convenience symlink pointing at the results root.
    # NOTE(review): shell "rm" is best-effort and silently no-ops on failure.
    os.system("rm ./result")
    os.symlink(os.path.join(opt["path"]["results_root"], ".."), "./result")

    if opt["dist"]:
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args))
    else:
        main_worker(0, 1, opt, args)


def main_worker(gpu, ngpus_per_node, opt, args):
    """Initialize the (optionally distributed) environment on one GPU, build
    all configured test datasets/loaders, and run each through `validate`.

    Args:
        gpu: local GPU index (also the spawn process index).
        ngpus_per_node: GPUs per node, used to compute the global rank.
        opt: parsed option dict.
        args: CLI namespace from parse_args().
    """
    if opt["dist"]:
        if args.dist_url == "env://" and args.rank == -1:
            # BUGFIX: the env-provided rank was previously read into a local
            # variable that the next statement immediately overwrote; store it
            # on args so the global-rank computation actually uses it.
            args.rank = int(os.environ["RANK"])
        rank = args.rank * ngpus_per_node + gpu
        print(
            f"Init process group: dist_url: {args.dist_url}, world_size: {args.world_size}, rank: {rank}"
        )
        dist.init_process_group(
            backend="nccl",
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=rank,
        )
        torch.cuda.set_device(gpu)
    else:
        rank = 0

    torch.backends.cudnn.benchmark = True

    util.setup_logger(
        "base",
        opt["path"]["log"],
        "test_" + opt["name"] + "_rank{}".format(rank),
        level=logging.INFO,
        screen=True,
        tofile=True,
    )
    measure = IQA(metrics=opt["metrics"], cuda=True)
    logger = logging.getLogger("base")
    logger.info(option.dict2str(opt))

    # Create test dataset and dataloader
    test_datasets = []
    test_loaders = []
    for phase, dataset_opt in sorted(opt["datasets"].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt, opt["dist"])
        if rank == 0:
            logger.info(
                "Number of test images in [{:s}]: {:d}".format(
                    dataset_opt["name"], len(test_set)
                )
            )
        test_datasets.append(test_set)
        test_loaders.append(test_loader)

    # load pretrained model by default
    model = create_model(opt)

    for test_dataset, test_loader in zip(test_datasets, test_loaders):
        test_set_name = test_dataset.opt["name"]
        dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name)
        if rank == 0:
            logger.info("\nTesting [{:s}]...".format(test_set_name))
            util.mkdir(dataset_dir)
        # All ranks must enter validate: it performs dist.reduce/barrier.
        validate(
            model,
            test_dataset,
            test_loader,
            opt,
            measure,
            dataset_dir,
            test_set_name,
            logger,
        )


def validate(
    model, dataset, dist_loader, opt, measure, dataset_dir, test_set_name, logger
):
    """Run every sample of one test set through the model, save SR images,
    score them with `measure`, and (rank 0) log per-metric averages.

    Each rank scores the strided subset ``range(rank, len(dataset),
    world_size)``; per-sample scores are written into shared-size CUDA
    tensors that are dist.reduce'd to rank 0 before averaging.
    """
    test_results = {}
    test_results_y = {}
    for metric in opt["metrics"]:
        test_results[metric] = torch.zeros((len(dataset))).cuda()
        test_results_y[metric] = torch.zeros((len(dataset))).cuda()

    if opt["dist"]:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        world_size = 1
        rank = 0
    # Global dataset indices handled by this rank, in loader order.
    indices = list(range(rank, len(dataset), world_size))

    for idx, test_data in enumerate(dist_loader):
        idx = indices[idx]
        img_path = test_data["src_path"][0]
        img_name = img_path.split("/")[-1].split(".")[0]

        model.test(test_data)
        visuals = model.get_current_visuals()
        sr_img = util.tensor2img(visuals["sr"])  # uint8

        suffix = opt["suffix"]
        if suffix:
            save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png")
        else:
            save_img_path = os.path.join(dataset_dir, img_name + ".png")
        util.save_img(sr_img, save_img_path)

        message = "img:{:15s}; ".format(img_name)

        crop_border = opt["crop_border"] if opt["crop_border"] else opt["scale"]
        if crop_border == 0:
            cropped_sr_img = sr_img
        else:
            cropped_sr_img = sr_img[
                crop_border:-crop_border, crop_border:-crop_border, :
            ]

        if "tgt" in test_data.keys():
            gt_img = util.tensor2img(test_data["tgt"][0].double().cpu())
            if crop_border == 0:
                cropped_gt_img = gt_img
            else:
                cropped_gt_img = gt_img[
                    crop_border:-crop_border, crop_border:-crop_border, :
                ]
        else:
            # BUGFIX: gt_img was left undefined on this path, raising a
            # NameError at the "if gt_img is not None" check below.
            gt_img = None
            cropped_gt_img = None

        message += "Scores - "
        scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"])
        for k, v in scores.items():
            test_results[k][idx] = v
            message += "{}: {:.6f}; ".format(k, v)

        if sr_img.shape[2] == 3:  # RGB image
            sr_img_y = bgr2ycbcr(sr_img, only_y=True)
            if crop_border == 0:
                cropped_sr_img_y = sr_img_y * 255
            else:
                cropped_sr_img_y = (
                    sr_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255
                )

            if gt_img is not None:
                gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                if crop_border == 0:
                    cropped_gt_img_y = gt_img_y * 255
                else:
                    cropped_gt_img_y = (
                        gt_img_y[crop_border:-crop_border, crop_border:-crop_border]
                        * 255
                    )
            else:
                gt_img_y = None
                # BUGFIX: cropped_gt_img_y was undefined when no GT exists,
                # raising a NameError in the measure() call below.
                cropped_gt_img_y = None

            message += "Y Scores - "
            scores = measure(
                res=cropped_sr_img_y, ref=cropped_gt_img_y, metrics=opt["metrics"]
            )
            for k, v in scores.items():
                test_results_y[k][idx] = v
                message += "{}: {:.6f}; ".format(k, v)

        logger.info(message)

    if opt["dist"]:
        for k, v in test_results.items():
            dist.reduce(v, dst=0)
        dist.barrier()
        for k, v in test_results_y.items():
            dist.reduce(v, dst=0)
        dist.barrier()

    # log
    avg_results = {}
    message = "Average Results for {}\n".format(test_set_name)
    if rank == 0:
        for k, v in test_results.items():
            avg_results[k] = sum(v) / len(v)
            message += "{}: {:.6f}; ".format(k, avg_results[k])
        logger.info(message)

    avg_results_y = {}
    message = "Average Results on Y channel for {}\n".format(test_set_name)
    if rank == 0:
        for k, v in test_results_y.items():
            # BUGFIX: the Y-channel averages were written into avg_results,
            # leaving avg_results_y empty and clobbering the RGB averages.
            avg_results_y[k] = sum(v) / len(v)
            message += "{}: {:.6f}; ".format(k, avg_results_y[k])
        logger.info(message)


if __name__ == "__main__":
    main()
util
import utils.option as option
from data import create_dataloader, create_dataset
from metrics import IQA
from models import create_model


def parse_args():
    """Parse command-line options for training."""
    parser = argparse.ArgumentParser(description="Train keypoints network")
    # general
    parser.add_argument(
        "--opt", help="experiment configure file name", required=True, type=str
    )
    parser.add_argument(
        "--root_path",
        help="experiment configure file name",
        default="../../../",
        type=str,
    )
    # distributed training
    parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str)
    parser.add_argument(
        "--world-size",
        default=1,
        type=int,
        help="number of nodes for distributed training",
    )
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:23456",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--rank", default=0, type=int, help="node rank for distributed training"
    )
    return parser.parse_args()


def setup_dataloaer(opt, logger):
    """Create the train/val datasets and loaders described in opt["datasets"].

    Returns (train_set, train_loader, val_set, val_loader, total_iters,
    total_epochs); raises NotImplementedError for any unknown phase.
    """
    if opt["dist"]:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1

    for phase, ds_opt in opt["datasets"].items():
        if phase == "train":
            train_set = create_dataset(ds_opt)
            train_loader = create_dataloader(train_set, ds_opt, opt["dist"])
            total_iters = opt["train"]["niter"]
            # epochs needed to reach the requested iteration budget
            total_epochs = total_iters // (len(train_loader) - 1) + 1
            if rank == 0:
                logger.info(
                    "Number of train images: {:,d}, iters: {:,d}".format(
                        len(train_set), len(train_loader)
                    )
                )
                logger.info(
                    "Total epochs needed: {:d} for iters {:,d}".format(
                        total_epochs, opt["train"]["niter"]
                    )
                )
        elif phase == "val":
            val_set = create_dataset(ds_opt)
            val_loader = create_dataloader(val_set, ds_opt, opt["dist"])
            if rank == 0:
                logger.info(
                    "Number of val images in [{:s}]: {:d}".format(
                        ds_opt["name"], len(val_set)
                    )
                )
        else:
            raise NotImplementedError("Phase [{:s}] is not recognized.".format(phase))

    assert train_loader is not None
    assert val_loader is not None

    return train_set, train_loader, val_set, val_loader, total_iters, total_epochs


def main():
    """Entry point: parse args, prepare experiment folders, launch workers."""
    args = parse_args()
    opt = option.parse(args.opt, args.root_path, is_train=True)
    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    ngpus_per_node = torch.cuda.device_count()
    args.world_size = ngpus_per_node * args.world_size
    opt["dist"] = args.world_size > 1

    if opt["train"].get("resume_state", None) is None:
        # rename experiment folder if exists
        util.mkdir_and_rename(opt["path"]["experiments_root"])
        util.mkdirs(
            (path for key, path in opt["path"].items() if not key == "experiments_root")
        )

    # refresh the ./log convenience symlink
    os.system("rm ./log")
    os.symlink(os.path.join(opt["path"]["experiments_root"], ".."), "./log")

    if opt["dist"]:
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args))
    else:
        main_worker(0, 1, opt, args)


def main_worker(gpu, ngpus_per_node, opt, args):
    """Per-process training worker (continued below)."""
    if opt["dist"]:
        if args.dist_url == "env://" and args.rank == -1:
            rank = int(os.environ["RANK"])
        rank = args.rank * ngpus_per_node + gpu
        print(
            f"Init process group: dist_url: \
{args.dist_url}, world_size: {args.world_size}, rank: {rank}"
        )
        dist.init_process_group(
            backend="nccl",
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=rank,
        )
        torch.cuda.set_device(gpu)
    else:
        rank = 0

    # NOTE(review): `seed` is only tested for None; util.set_random_seed is
    # called with the rank, not the manual seed -- confirm intent upstream.
    seed = opt["train"]["manual_seed"]
    if seed is None:
        util.set_random_seed(rank)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # setup tensorboard and val logger
    if rank == 0:
        if opt["use_tb_logger"] and "debug" not in opt["name"]:
            tb_logger = SummaryWriter(log_dir="log/{}/tb_logger/".format(opt["name"]))
        util.setup_logger(
            "val",
            opt["path"]["log"],
            "val_" + opt["name"],
            level=logging.INFO,
            screen=True,
            tofile=True,
        )
    measure = IQA(metrics=opt["metrics"], cuda=True)

    # config loggers.
Before it, the log will not work util.setup_logger( "base", opt["path"]["log"], "train_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO if rank == 0 else logging.ERROR, screen=True, tofile=True, ) logger = logging.getLogger("base") if rank == 0: logger.info(option.dict2str(opt)) # create dataset ( train_set, train_loader, val_set, val_loader, total_iters, total_epochs, ) = setup_dataloaer(opt, logger) # create model model = create_model(opt) # loading resume state if exists if opt["train"].get("resume_state", None): # distributed resuming: all load into default GPU device_id = gpu resume_state = torch.load( opt["train"]["resume_state"], map_location=lambda storage, loc: storage.cuda(device_id), ) logger.info( "Resuming training from epoch: {}, iter: {}.".format( resume_state["epoch"], resume_state["iter"] ) ) start_epoch = resume_state["epoch"] current_step = resume_state["iter"] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 logger.info( "Start training from epoch: {:d}, iter: {:d}".format(start_epoch, current_step) ) data_time, iter_time = time.time(), time.time() avg_data_time = avg_iter_time = 0 count = 0 for epoch in range(start_epoch, total_epochs + 1): for _, train_data in enumerate(train_loader): current_step += 1 count += 1 if current_step > total_iters: break data_time = time.time() - data_time avg_data_time = (avg_data_time * (count - 1) + data_time) / count model.feed_data(train_data) model.optimize_parameters(current_step) model.update_learning_rate( current_step, warmup_iter=opt["train"]["warmup_iter"] ) iter_time = time.time() - iter_time avg_iter_time = (avg_iter_time * (count - 1) + iter_time) / count # log if current_step % opt["logger"]["print_freq"] == 0: logs = model.get_current_log() message = ( f" " ) message += f'[time (data): {avg_iter_time:.3f} ({avg_data_time:.3f})] ' for k, v in logs.items(): message += "{:s}: {:.4e}; ".format(k, v) # tensorboard logger if 
opt["use_tb_logger"] and "debug" not in opt["name"]: if rank == 0: tb_logger.add_scalar(k, v, current_step) logger.info(message) # validation if current_step % opt["train"]["val_freq"] == 0: avg_results = validate( model, val_set, val_loader, opt, measure, epoch, current_step ) # tensorboard logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: for k, v in avg_results.items(): tb_logger.add_scalar(k, v, current_step) # save models and training states if current_step % opt["logger"]["save_checkpoint_freq"] == 0: if rank == 0: logger.info("Saving models and training states.") model.save(current_step) model.save_training_state(epoch, current_step) data_time = time.time() iter_time = time.time() if rank == 0: logger.info("Saving the final model.") model.save("latest") logger.info("End of training.") if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger.close() def validate(model, dataset, dist_loader, opt, measure, epoch, current_step): test_results = {} for metric in opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 if rank == 0: pbar = tqdm(total=len(dataset), leave=False, dynamic_ncols=True) indices = list(range(rank, len(dataset), world_size)) for ( idx, val_data, ) in enumerate(dist_loader): idx = indices[idx] LR_img = val_data["src"] lr_img = util.tensor2img(LR_img) # save LR image for reference model.test(val_data) visuals = model.get_current_visuals() # Save images for reference img_name = val_data["src_path"][0].split("/")[-1].split(".")[0] img_dir = os.path.join(opt["path"]["val_images"], img_name) util.mkdir(img_dir) save_lr_path = os.path.join(img_dir, "{:s}_LR.png".format(img_name)) util.save_img(lr_img, save_lr_path) sr_img = util.tensor2img(visuals["sr"]) # uint8 save_img_path = os.path.join( img_dir, "{:s}_{:d}.png".format(img_name, current_step) ) util.save_img(sr_img, save_img_path) if 
"fake_lr" in visuals.keys(): fake_lr_img = util.tensor2img(visuals["fake_lr"]) save_img_path = os.path.join( img_dir, f"fake_lr_{current_step:d}.png" ) util.save_img(fake_lr_img, save_img_path) # calculate scores crop_size = opt["scale"] cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] if "tgt" in val_data.keys(): gt_img = util.tensor2img(val_data["tgt"]) cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] else: cropped_gt_img = gt_img = None scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v if rank == 0: for _ in range(world_size): pbar.update(1) if rank == 0: pbar.close() # log avg_results = {} message = " 0: nc = self.kernel_opt["nc"] if self.kernel_opt["spatial"]: zk = torch.randn(B, nc, h, w).to(inp.device) else: zk = torch.randn(B, nc, 1, 1).to(inp.device) inp_k = torch.cat([inp_k, zk], 1) else: nc = self.kernel_opt["nc"] if self.kernel_opt["spatial"]: inp_k = torch.randn(B, nc, h, w).to(inp.device) else: inp_k = torch.randn(B, nc, 1, 1).to(inp.device) ksize = self.kernel_opt["ksize"] kernel = self.deg_kernel(inp_k).view(B, 1, ksize**2, *inp_k.shape[2:]) x = inp.view(B*C, 1, H, W) x = F.unfold( self.pad(x), kernel_size=ksize, stride=self.scale, padding=0 ).view(B, C, ksize**2, h, w) x = torch.mul(x, kernel).sum(2).view(B, C, h, w) kernel = kernel.view(B, ksize**2, *inp_k.shape[2:]) else: x = F.interpolate(inp, scale_factor=1/self.scale, mode="bicubic", align_corners=False) kernel = None # noise if self.noise_opt is not None: if self.noise_opt["mix"]: # inp_n = x.detach() inp_n = F.interpolate(inp, scale_factor=1/self.scale, mode="bicubic", align_corners=False) if self.noise_opt["nc"] > 0: nc = self.noise_opt["nc"] zn = torch.randn(B, nc, h, w).to(inp.device) inp_n = torch.cat([inp_n, zn], 1) else: nc = self.noise_opt["nc"] inp_n = torch.randn(B, nc, h, w).to(inp.device) noise = self.deg_noise(inp_n) x = x + noise else: noise = None 
return x, kernel, noise ================================================ FILE: codes/config/DSGANSR/archs/discriminator.py ================================================ import torch import torch.nn as nn import torchvision import functools from utils.registry import ARCH_REGISTRY @ARCH_REGISTRY.register() class DiscriminatorVGG128(nn.Module): def __init__(self, in_nc, nf): super().__init__() # [64, 128, 128] self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) self.bn0_1 = nn.BatchNorm2d(nf, affine=True) # [64, 64, 64] self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) # [128, 32, 32] self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) # [256, 16, 16] self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) # [512, 8, 8] self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) self.linear1 = nn.Linear(512 * 4 * 4, 100) self.linear2 = nn.Linear(100, 1) # activation function self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): fea = self.lrelu(self.conv0_0(x)) fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) fea = 
self.lrelu(self.bn3_0(self.conv3_0(fea))) fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) fea = fea.view(fea.size(0), -1) fea = self.lrelu(self.linear1(fea)) out = self.linear2(fea) return out @ARCH_REGISTRY.register() class DiscriminatorVGG32(nn.Module): def __init__(self, in_nc, nf): super().__init__() # [64, 128, 128] self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) self.bn0_1 = nn.BatchNorm2d(nf, affine=True) # [64, 64, 64] self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) # [128, 32, 32] self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) # [256, 16, 16] self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) # [512, 8, 8] self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) self.linear1 = nn.Linear(512, 100) self.linear2 = nn.Linear(100, 1) # activation function self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): fea = self.lrelu(self.conv0_0(x)) fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) fea = 
self.lrelu(self.bn3_0(self.conv3_0(fea))) fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) fea = fea.view(fea.size(0), -1) fea = self.lrelu(self.linear1(fea)) out = self.linear2(fea) return out @ARCH_REGISTRY.register() class PatchGANDiscriminator(nn.Module): """Defines a PatchGAN discriminator""" def __init__(self, in_c, nf, nb, stride=1, norm_layer=nn.InstanceNorm2d): """Construct a PatchGAN discriminator Parameters: input_nc (int) -- the number of channels in input images ndf (int) -- the number of filters in the last conv layer n_layers (int) -- the number of conv layers in the discriminator norm_layer -- normalization layer """ super().__init__() if ( type(norm_layer) == functools.partial ): # no need to use bias as BatchNorm2d has affine parameters use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d kw = 3 padw = 1 sequence = [ nn.Conv2d(in_c, nf, kernel_size=kw, stride=1, padding=padw), nn.LeakyReLU(0.2, True), ] nf_mult = 1 nf_mult_prev = 1 for n in range(1, nb): # gradually increase the number of filters nf_mult_prev = nf_mult nf_mult = min(2 ** n, 8) sequence += [ nn.Conv2d( nf * nf_mult_prev, nf * nf_mult, kernel_size=kw, stride=stride, padding=padw, bias=use_bias, ), norm_layer(nf * nf_mult), nn.LeakyReLU(0.2, True), ] nf_mult_prev = nf_mult nf_mult = min(2 ** nb, 8) sequence += [ nn.Conv2d( nf * nf_mult_prev, nf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias, ), norm_layer(nf * nf_mult), nn.LeakyReLU(0.2, True), ] sequence += [ nn.Conv2d(nf * nf_mult, nf, kernel_size=kw, stride=1, padding=padw) ] # output 1 channel prediction map self.model = nn.Sequential(*sequence) def forward(self, input): """Standard forward.""" return self.model(input) ================================================ FILE: codes/config/DSGANSR/archs/edsr.py ================================================ import math import 
torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from utils.registry import ARCH_REGISTRY def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias ) class MeanShift(nn.Conv2d): def __init__( self, rgb_range, rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1, ): super(MeanShift, self).__init__(3, 3, kernel_size=1) std = torch.Tensor(rgb_std) self.weight.data = torch.eye(3).view(3, 3, 1, 1) self.weight.data.div_(std.view(3, 1, 1, 1)) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) self.bias.data.div_(std) self.requires_grad = False class BasicBlock(nn.Sequential): def __init__( self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True), ): m = [ nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), stride=stride, bias=bias, ) ] if bn: m.append(nn.BatchNorm2d(out_channels)) if act is not None: m.append(act) super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): def __init__( self, conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(ResBlock, self).__init__() m = [] for i in range(2): m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: m.append(nn.BatchNorm2d(n_feat)) if i == 0: m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale def forward(self, x): res = self.body(x).mul(self.res_scale) res += x return res class Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
for _ in range(int(math.log(scale, 2))): m.append(conv(n_feat, 4 * n_feat, 3, bias)) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 3: m.append(conv(n_feat, 9 * n_feat, 3, bias)) m.append(nn.PixelShuffle(3)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 1: m.append(nn.Identity()) else: raise NotImplementedError super(Upsampler, self).__init__(*m) def make_model(args, parent=False): return RCAN(args) ## Channel Attention (CA) Layer @ARCH_REGISTRY.register() class EDSR(nn.Module): def __init__(self, nb, nf, res_scale=0.1, upscale=4, conv=default_conv): super(EDSR, self).__init__() n_resblocks = nb n_feats = nf kernel_size = 3 scale = upscale act = nn.ReLU(True) # url_name = 'r{}f{}x{}'.format(nb, nf, upscale) # if url_name in url: # self.url = url[url_name] # else: # self.url = None self.sub_mean = MeanShift(255.0, sign=-1) self.add_mean = MeanShift(255.0, sign=1) # define head module m_head = [conv(3, n_feats, kernel_size)] # define body module m_body = [ ResBlock(conv, n_feats, kernel_size, act=act, res_scale=res_scale) for _ in range(n_resblocks) ] m_body.append(conv(n_feats, n_feats, kernel_size)) # define tail module m_tail = [ Upsampler(conv, scale, n_feats, act=False), conv(n_feats, 3, kernel_size), ] self.head = nn.Sequential(*m_head) self.body = nn.Sequential(*m_body) self.tail = nn.Sequential(*m_tail) def forward(self, x): x = self.sub_mean(x * 255.0) x = self.head(x) res = self.body(x) res += x x = self.tail(res) x = self.add_mean(x) / 255.0 return x ================================================ FILE: codes/config/DSGANSR/archs/loss.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import lpips as lp from utils.registry import LOSS_REGISTRY from .vgg import VGGFeatureExtractor @LOSS_REGISTRY.register() class ColorLoss(nn.Module): def __init__(self, ksize=5, sigma=None, stride=1, recursion=1, 
loss_type="l1"): super().__init__() self.stride = stride self.ksize = ksize self.recursion = recursion self.loss_type = loss_type if sigma is None: sigma = ksize / 6 ax = torch.arange(0, ksize) - (ksize - 1) / 2 xx, yy = torch.meshgrid(ax, ax) dis = (xx ** 2 + yy ** 2) dis = torch.exp(-dis / 2 / sigma ** 2) dis = dis / dis.sum() weight = dis.view(1, 1, ksize, ksize).repeat(3, 1, 1, 1) self.register_buffer("weight", weight) def forward(self, src, tgt): for i in range(self.recursion): tgt = F.conv2d(tgt, self.weight, stride=self.stride, padding=self.ksize//2, groups=3) if self.loss_type == "l1": loss = F.l1_loss(src, tgt) elif self.loss_type == "mse": loss = F.mse_loss(src, tgt) return loss @LOSS_REGISTRY.register() class GaussGuided(nn.Module): def __init__(self, ksize, sigma): super().__init__() ax = torch.arange(0, ksize) - ksize//2 xx, yy = torch.meshgrid(ax, ax) dis = (xx ** 2 + yy ** 2) dis = torch.exp(-dis / sigma ** 2) dis = dis / dis.sum() self.register_buffer("gauss", dis.view(1, ksize**2, 1, 1)) def forward(self, kernel): return F.mse_loss(self.gauss, kernel) @LOSS_REGISTRY.register() class PerceptualLossLPIPS(nn.Module): def __init__(self, net="alex", normalize=True): super().__init__() self.fn = lp.LPIPS(net=net, spatial=True) for p in self.fn.parameters(): p.requires_grad = False self.normalize = normalize def forward(self, res, ref): return self.fn(res, ref, normalize=self.normalize).mean(), None @LOSS_REGISTRY.register() class MSELoss(nn.Module): def __init__(self, *args, **kwargs): super().__init__() def forward(self, res, ref): return F.mse_loss(res, ref) @LOSS_REGISTRY.register() class L1Loss(nn.Module): def __init__(self, *args, **kwargs): super().__init__() def forward(self, res, ref): return F.l1_loss(res, ref) @LOSS_REGISTRY.register() class GANLoss(nn.Module): """Define GAN loss. Args: gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. real_label_val (float): The value for real label. Default: 1.0. 
fake_label_val (float): The value for fake label. Default: 0.0. """ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): super(GANLoss, self).__init__() self.gan_type = gan_type self.real_label_val = real_label_val self.fake_label_val = fake_label_val if self.gan_type == "vanilla": self.loss = nn.BCEWithLogitsLoss() elif self.gan_type == "lsgan": self.loss = nn.MSELoss() elif self.gan_type == "wgan": self.loss = self._wgan_loss elif self.gan_type == "wgan_softplus": self.loss = self._wgan_softplus_loss elif self.gan_type == "hinge": self.loss = nn.ReLU() else: raise NotImplementedError(f"GAN type {self.gan_type} is not implemented.") def _wgan_loss(self, input, target): """wgan loss. Args: input (Tensor): Input tensor. target (bool): Target label. Returns: Tensor: wgan loss. """ return -input.mean() if target else input.mean() def _wgan_softplus_loss(self, input, target): """wgan loss with soft plus. softplus is a smooth approximation to the ReLU function. In StyleGAN2, it is called: Logistic loss for discriminator; Non-saturating loss for generator. Args: input (Tensor): Input tensor. target (bool): Target label. Returns: Tensor: wgan loss. """ return F.softplus(-input).mean() if target else F.softplus(input).mean() def get_target_label(self, input, target_is_real): """Get target label. Args: input (Tensor): Input tensor. target_is_real (bool): Whether the target is real or fake. Returns: (bool | Tensor): Target tensor. Return bool for wgan, otherwise, return Tensor. """ if self.gan_type in ["wgan", "wgan_softplus"]: return target_is_real target_val = self.real_label_val if target_is_real else self.fake_label_val return input.new_ones(input.size()) * target_val def forward(self, input, target_is_real, is_disc=False): """ Args: input (Tensor): The input for the loss module, i.e., the network prediction. target_is_real (bool): Whether the targe is real or fake. is_disc (bool): Whether the loss for discriminators or not. Default: False. 
Returns: Tensor: GAN loss value. """ target_label = self.get_target_label(input, target_is_real) if self.gan_type == "hinge": if is_disc: # for discriminators in hinge-gan input = -input if target_is_real else input loss = self.loss(1 + input).mean() else: # for generators in hinge-gan loss = -input.mean() else: # other gan types loss = self.loss(input, target_label) return loss @LOSS_REGISTRY.register() class PerceptualLoss(nn.Module): """Perceptual loss with commonly used style loss. Args: layer_weights (dict): The weight for each layer of vgg feature. Here is an example: {'conv5_4': 1.}, which means the conv5_4 feature layer (before relu5_4) will be extracted with weight 1.0 in calculting losses. vgg_type (str): The type of vgg network used as feature extractor. Default: 'vgg19'. use_input_norm (bool): If True, normalize the input image in vgg. Default: True. range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. Default: False. perceptual_weight (float): If `perceptual_weight > 0`, the perceptual loss will be calculated and the loss will multiplied by the weight. Default: 1.0. style_weight (float): If `style_weight > 0`, the style loss will be calculated and the loss will multiplied by the weight. Default: 0. criterion (str): Criterion used for perceptual loss. Default: 'l1'. 
""" def __init__( self, layer_weights, vgg_type="vgg19", use_input_norm=True, range_norm=False, perceptual_weight=1.0, style_weight=0.0, criterion="l1", ): super(PerceptualLoss, self).__init__() self.perceptual_weight = perceptual_weight self.style_weight = style_weight self.layer_weights = layer_weights self.vgg = VGGFeatureExtractor( layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, use_input_norm=use_input_norm, range_norm=range_norm, ) self.criterion_type = criterion if self.criterion_type == "l1": self.criterion = torch.nn.L1Loss() elif self.criterion_type == "l2": self.criterion = torch.nn.L2loss() elif self.criterion_type == "fro": self.criterion = None else: raise NotImplementedError(f"{criterion} criterion has not been supported.") def forward(self, x, gt): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). gt (Tensor): Ground-truth tensor with shape (n, c, h, w). Returns: Tensor: Forward results. """ # extract vgg features x_features = self.vgg(x) gt_features = self.vgg(gt.detach()) # calculate perceptual loss if self.perceptual_weight > 0: percep_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": percep_loss += ( torch.norm(x_features[k] - gt_features[k], p="fro") * self.layer_weights[k] ) else: percep_loss += ( self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] ) percep_loss *= self.perceptual_weight else: percep_loss = None # calculate style loss if self.style_weight > 0: style_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": style_loss += ( torch.norm( self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p="fro", ) * self.layer_weights[k] ) else: style_loss += ( self.criterion( self._gram_mat(x_features[k]), self._gram_mat(gt_features[k]), ) * self.layer_weights[k] ) style_loss *= self.style_weight else: style_loss = None return percep_loss, style_loss def _gram_mat(self, x): """Calculate Gram matrix. 
Args: x (torch.Tensor): Tensor with shape of (n, c, h, w). Returns: torch.Tensor: Gram matrix. """ n, c, h, w = x.size() features = x.view(n, c, w * h) features_t = features.transpose(1, 2) gram = features.bmm(features_t) / (c * h * w) return gram @LOSS_REGISTRY.register() class CharbonnierLoss(nn.Module): """Charbonnier Loss (L1)""" def __init__(self, eps=1e-6): super(CharbonnierLoss, self).__init__() self.eps = eps def forward(self, x, y): diff = x - y loss = torch.mean(torch.sqrt(diff * diff + self.eps)) return loss class GradientPenaltyLoss(nn.Module): def __init__(self, device=torch.device("cpu")): super(GradientPenaltyLoss, self).__init__() self.register_buffer("grad_outputs", torch.Tensor()) self.grad_outputs = self.grad_outputs.to(device) def get_grad_outputs(self, input): if self.grad_outputs.size() != input.size(): self.grad_outputs.resize_(input.size()).fill_(1.0) return self.grad_outputs def forward(self, interp, interp_crit): grad_outputs = self.get_grad_outputs(interp_crit) grad_interp = torch.autograd.grad( outputs=interp_crit, inputs=interp, grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True, )[0] grad_interp = grad_interp.view(grad_interp.size(0), -1) grad_interp_norm = grad_interp.norm(2, dim=1) loss = ((grad_interp_norm - 1) ** 2).mean() return loss ================================================ FILE: codes/config/DSGANSR/archs/lr_scheduler.py ================================================ import math from collections import Counter, defaultdict import torch from torch.optim.lr_scheduler import _LRScheduler from utils.registry import LR_SCHEDULER_REGISTRY @LR_SCHEDULER_REGISTRY.register() class LinearDecayLR(_LRScheduler): def __init__( self, optimizer, decay_prop, total_steps, last_epoch=-1, ): self.decay_prop = decay_prop self.total_steps = total_steps super().__init__(optimizer, last_epoch) def get_lr(self): return [ group["initial_lr"] * (1 - (self.last_epoch + 1) * self.decay_prop / self.total_steps) for 
group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class MultiStepRestartLR(_LRScheduler): def __init__( self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, clear_state=False, last_epoch=-1, ): self.milestones = Counter(milestones) self.gamma = gamma self.clear_state = clear_state self.restarts = restarts if restarts else [0] self.restart_weights = weights if weights else [1] assert len(self.restarts) == len( self.restart_weights ), "restarts and their weights do not match." super().__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch in self.restarts: if self.clear_state: self.optimizer.state = defaultdict(dict) weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [ group["initial_lr"] * weight for group in self.optimizer.param_groups ] if self.last_epoch not in self.milestones: return [group["lr"] for group in self.optimizer.param_groups] return [ group["lr"] * self.gamma ** self.milestones[self.last_epoch] for group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class CosineAnnealingRestartLR(_LRScheduler): def __init__( self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1 ): self.T_period = T_period self.T_max = self.T_period[0] # current T period self.eta_min = eta_min self.restarts = restarts if restarts else [0] self.restart_weights = weights if weights else [1] self.last_restart = 0 assert len(self.restarts) == len( self.restart_weights ), "restarts and their weights do not match." 
def initialize_weights(net_l, scale=1):
    """Kaiming-initialize Conv2d/Linear layers (weights scaled by ``scale``,
    biases zeroed) and reset BatchNorm2d to identity, for every module of
    every network in ``net_l`` (a single network or a list)."""
    networks = net_l if isinstance(net_l, list) else [net_l]
    for network in networks:
        for module in network.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(module.weight, a=0, mode="fan_in")
                # Down-scale for residual blocks (smaller initial residuals).
                module.weight.data *= scale
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias.data, 0.0)


def make_layer(block, n_layers):
    """Stack ``n_layers`` fresh instances produced by the ``block`` factory."""
    return nn.Sequential(*(block() for _ in range(n_layers)))
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """'Same'-padded convolution used throughout the EDSR/RCAN archs."""
    return nn.Conv2d(
        in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias
    )


class MeanShift(nn.Conv2d):
    """Fixed 1x1 conv that subtracts (sign=-1) or adds (sign=+1) the
    dataset RGB mean, optionally divided by a per-channel std.

    Fix: the original set ``self.requires_grad = False``, which on an
    nn.Module merely creates a plain attribute and leaves the weight and
    bias trainable; the actual parameters are frozen here instead.
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        # Identity mixing across channels, scaled by 1/std.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        # Freeze the parameters themselves (a constant pre/post-processing op).
        for p in self.parameters():
            p.requires_grad = False


class BasicBlock(nn.Sequential):
    """conv (+ optional BN) (+ optional activation) convenience block."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        bias=False,
        bn=True,
        act=nn.ReLU(True),
    ):
        m = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                padding=(kernel_size // 2),
                stride=stride,
                bias=bias,
            )
        ]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)


class ResBlock(nn.Module):
    """EDSR residual block: conv-act-conv, scaled residual + identity."""

    def __init__(
        self,
        conv,
        n_feat,
        kernel_size,
        bias=True,
        bn=False,
        act=nn.ReLU(True),
        res_scale=1,
    ):
        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                # Activation only between the two convs.
                m.append(act)
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x
        return res


def flow_warp(x, flow, interp_mode="bilinear", padding_mode="zeros"):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): size (N, C, H, W)
        flow (Tensor): size (N, H, W, 2), pixel-unit displacements
        interp_mode (str): 'nearest' or 'bilinear'
        padding_mode (str): 'zeros' or 'border' or 'reflection'

    Returns:
        Tensor: warped image or feature map
    """
    assert x.size()[-2:] == flow.size()[1:3]
    B, C, H, W = x.size()
    # mesh grid of absolute pixel coordinates
    grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False
    grid = grid.type_as(x)
    vgrid = grid + flow
    # Scale grid to [-1, 1].
    # NOTE(review): this normalization matches the align_corners=True
    # convention, but grid_sample is called with its default
    # (align_corners=False on modern PyTorch) — confirm intended behavior.
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
    return output
def make_model(args, parent=False):
    """Factory hook kept for compatibility with the original RCAN repo."""
    return RCAN(args)


## Channel Attention (CA) Layer
class CALayer(nn.Module):
    """Squeeze-and-excitation style channel attention: global average
    pool -> 1x1-conv bottleneck -> sigmoid gate -> rescale the input."""

    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gate = self.conv_du(self.avg_pool(x))
        return x * gate


## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    """conv-act-conv followed by channel attention, plus identity skip."""

    def __init__(
        self,
        conv,
        n_feat,
        kernel_size,
        reduction,
        bias=True,
        bn=False,
        act=nn.ReLU(True),
        res_scale=1,
    ):
        super(RCAB, self).__init__()
        layers = []
        for idx in range(2):
            layers.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                layers.append(nn.BatchNorm2d(n_feat))
            if idx == 0:
                layers.append(act)
        layers.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale  # stored but deliberately not applied

    def forward(self, x):
        # res_scale intentionally unused here (matches original behavior).
        return x + self.body(x)
## Residual Channel Attention Network (RCAN)
@ARCH_REGISTRY.register()
class RCAN(nn.Module):
    """RCAN super-resolution network.

    Args:
        ng (int): number of residual groups.
        nb (int): number of RCAB blocks per group.
        nf (int): number of feature channels.
        reduction (int): channel-attention bottleneck reduction ratio.
        upscale (int): super-resolution scale factor.
        conv: conv-layer factory.
    """

    def __init__(self, ng, nb, nf, reduction=16, upscale=4, conv=default_conv):
        super(RCAN, self).__init__()

        n_resgroups = ng
        n_resblocks = nb
        n_feats = nf
        kernel_size = 3
        reduction = reduction
        scale = upscale
        act = nn.ReLU(True)

        # RGB mean for DIV2K
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = MeanShift(1.0, rgb_mean, rgb_std, -1)

        # define head module
        modules_head = [conv(3, n_feats, kernel_size)]

        # define body module
        modules_body = [
            ResidualGroup(
                conv,
                n_feats,
                kernel_size,
                reduction,
                act=act,
                res_scale=1.0,
                n_resblocks=nb,
            )
            for _ in range(ng)
        ]
        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        modules_tail = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, 3, kernel_size),
        ]

        self.add_mean = MeanShift(1.0, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        # Normalize, shallow features, deep body with long skip, then
        # upsample and de-normalize.
        x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)
        x = self.add_mean(x)

        return x

    def load_state_dict(self, state_dict, strict=False):
        # Custom loader: tolerate mismatched "tail" (upsampler) weights so a
        # checkpoint trained at one scale can seed a model at another scale.
        # Note: default strict=False deliberately differs from nn.Module.
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find("tail") >= 0:
                        print("Replace pre-trained upsampler to new one...")
                    else:
                        raise RuntimeError(
                            "While copying the parameter named {}, "
                            "whose dimensions in the model are {} and "
                            "whose dimensions in the checkpoint are {}.".format(
                                name, own_state[name].size(), param.size()
                            )
                        )
            elif strict:
                if name.find("tail") == -1:
                    raise KeyError('unexpected key "{}" in state_dict'.format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))


class ResidualDenseBlock_5C(nn.Module):
    """Five-conv residual dense block; output is scaled by 0.2 before the
    identity skip connection."""

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization (0.1 weight scaling for residual stability)
        initialize_weights(
            [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1
        )

    def forward(self, x):
        # Dense connectivity: each conv sees all previous feature maps.
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    """Residual in Residual Dense Block"""

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        # Scaled residual around the three dense blocks.
        return out * 0.2 + x


@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
    """ESRGAN-style RRDB super-resolution network (upscale 2, 3 or 4)."""

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4):
        super(RRDBNet, self).__init__()
        self.upscale = upscale
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if upscale == 4:
            self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk  # long skip over the RRDB trunk

        # Nearest-neighbor upsample + conv: once for x2/x3, twice for x4.
        if self.upscale == 2 or self.upscale == 3:
            fea = self.lrelu(
                self.upconv1(
                    F.interpolate(fea, scale_factor=self.upscale, mode="nearest")
                )
            )
        if self.upscale == 4:
            fea = self.lrelu(
                self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest"))
            )
            fea = self.lrelu(
                self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest"))
            )
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out
@ARCH_REGISTRY.register()
class Translator(nn.Module):
    """Image-to-image translator assembled from EDSR components.

    ``scale`` > 1 upsamples via the EDSR Upsampler, == 1 keeps resolution,
    < 1 downsamples with a strided conv head.  The network output is a
    residual added to the (resized) input; ``zero_tail`` zero-initializes
    the final conv so the module starts out as an identity-like mapping.
    """

    def __init__(self, nb, nf, scale=4, zero_tail=False, conv=default_conv):
        super().__init__()
        self.scale = scale

        # Head: plain conv for scale >= 1, strided conv for fractional scale.
        if scale >= 1:
            head_layers = [conv(3, nf, 3)]
        else:
            stride = int(1 / scale)
            head_layers = [
                nn.Conv2d(
                    3, nf, kernel_size=2 * stride + 1, stride=stride, padding=stride
                )
            ]

        # Body: nb residual blocks plus a fusion conv.
        body_layers = [
            ResBlock(conv, nf, 3, act=nn.ReLU(True), res_scale=1) for _ in range(nb)
        ]
        body_layers.append(conv(nf, nf, 3))

        # Tail: optional upsampler followed by the output conv.
        tail_layers = [
            Upsampler(conv, scale, nf, act=False) if scale > 1 else nn.Identity(),
            conv(nf, 3, 3),
        ]

        self.head = nn.Sequential(*head_layers)
        self.body = nn.Sequential(*body_layers)
        self.tail = nn.Sequential(*tail_layers)

        if zero_tail:
            nn.init.constant_(self.tail[-1].weight, 0)
            nn.init.constant_(self.tail[-1].bias, 0)

    def forward(self, x):
        residual = self.tail(self.body(self.head(x)))
        if self.scale == 1:
            return residual + x
        # NOTE(review): F.interpolate uses its default (nearest) mode here.
        return residual + F.interpolate(x, scale_factor=self.scale)


def insert_bn(names):
    """Insert bn layer after each conv.

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers.
    """
    names_with_bn = []
    for layer_name in names:
        names_with_bn.append(layer_name)
        if "conv" in layer_name:
            names_with_bn.append("bn" + layer_name.replace("conv", ""))
    return names_with_bn
""" def __init__( self, layer_name_list, vgg_type="vgg19", use_input_norm=True, range_norm=False, requires_grad=False, remove_pooling=False, pooling_stride=2, ): super(VGGFeatureExtractor, self).__init__() self.layer_name_list = layer_name_list self.use_input_norm = use_input_norm self.range_norm = range_norm self.names = NAMES[vgg_type.replace("_bn", "")] if "bn" in vgg_type: self.names = insert_bn(self.names) # only borrow layers that will be used to avoid unused params max_idx = 0 for v in layer_name_list: idx = self.names.index(v) if idx > max_idx: max_idx = idx if os.path.exists(VGG_PRETRAIN_PATH): vgg_net = getattr(vgg, vgg_type)(pretrained=False) state_dict = torch.load( VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage ) vgg_net.load_state_dict(state_dict) else: vgg_net = getattr(vgg, vgg_type)(pretrained=True) features = vgg_net.features[: max_idx + 1] modified_net = OrderedDict() for k, v in zip(self.names, features): if "pool" in k: # if remove_pooling is true, pooling operation will be removed if remove_pooling: continue else: # in some cases, we may want to change the default stride modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) else: modified_net[k] = v self.vgg_net = nn.Sequential(modified_net) if not requires_grad: self.vgg_net.eval() for param in self.parameters(): param.requires_grad = False else: self.vgg_net.train() for param in self.parameters(): param.requires_grad = True if self.use_input_norm: # the mean is for image with range [0, 1] self.register_buffer( "mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) ) # the std is for image with range [0, 1] self.register_buffer( "std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) ) def forward(self, x): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). Returns: Tensor: Forward results. 
""" if self.range_norm: x = (x + 1) / 2 if self.use_input_norm: x = (x - self.mean) / self.std output = {} for key, layer in self.vgg_net._modules.items(): x = layer(x) if key in self.layer_name_list: output[key] = x.clone() return output ================================================ FILE: codes/config/DSGANSR/count_flops.py ================================================ import argparse import sys import torch from torchsummaryX import summary sys.path.append("../../") import utils.option as option from models import create_model parser = argparse.ArgumentParser() parser.add_argument( "--opt", type=str, default="options/setting1/test/test_setting1_x4.yml", help="Path to option YMAL file of Predictor.", ) args = parser.parse_args() opt = option.parse(args.opt, root_path=".", is_train=True) opt = option.dict_to_nonedict(opt) model = create_model(opt) test_tensor = torch.randn(1, 3, 270, 180).cuda() for name, net in model.networks.items(): summary(net.cuda(), x=test_tensor) print("Above are results for net {}".format(name)) input() ================================================ FILE: codes/config/DSGANSR/inference.py ================================================ import argparse import logging import math import os import os.path as osp import random import sys import cv2 from collections import defaultdict from glob import glob from tqdm import tqdm import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp from tensorboardX import SummaryWriter sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from data.data_sampler import DistIterSampler from metrics import IQA from models import create_model #### options parser = argparse.ArgumentParser() parser.add_argument( "-opt", type=str, default="options/test/2020Track2.yml", help="Path to options YMAL file.", ) parser.add_argument("-input_dir", type=str, default="../../../data_samples/LR") 
parser.add_argument("-output_dir", type=str, default="../../../data_samples/BSRGAN") args = parser.parse_args() opt = option.parse(args.opt, is_train=False) opt = option.dict_to_nonedict(opt) model = create_model(opt) if not osp.exists(args.output_dir): os.makedirs(args.output_dir) test_files = glob(osp.join(args.input_dir, "*")) for inx, path in tqdm(enumerate(test_files)): name = path.split("/")[-1].split(".")[0] img = cv2.imread(path)[:, :, [2, 1, 0]] img = img.transpose(2, 0, 1)[None] / 255 img_t = torch.as_tensor(np.ascontiguousarray(img)).float() model.test({"src": img_t}, crop_size=512) outdict = model.get_current_visuals() sr = outdict["sr"] sr_im = util.tensor2img(sr) save_path = osp.join(args.output_dir, "{}_x{}.png".format(name, opt["scale"])) cv2.imwrite(save_path, sr_im) ================================================ FILE: codes/config/DSGANSR/models/__init__.py ================================================ import importlib import logging import os import os.path as osp from utils.registry import MODEL_REGISTRY logger = logging.getLogger("base") model_folder = osp.dirname(__file__) model_names = [ osp.splitext(osp.basename(v))[0] for v in os.listdir(model_folder) if v.endswith("_model.py") ] _model_modules = [ importlib.import_module(f"models.{file_name}") for file_name in model_names ] def create_model(opt, **kwarg): model = opt["model"] m = MODEL_REGISTRY.get(model)(opt, **kwarg) logger.info("Model [{:s}] is created.".format(m.__class__.__name__)) return m ================================================ FILE: codes/config/DSGANSR/models/base_model.py ================================================ import logging import os from collections import OrderedDict import torch import torch.nn as nn from torch.nn.parallel import DataParallel, DistributedDataParallel from archs import build_loss, build_network, build_scheduler from utils.registry import MODEL_REGISTRY logger = logging.getLogger("base") @MODEL_REGISTRY.register() class BaseModel: def 
    def __init__(self, opt):
        self.opt = opt
        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
            self.world_size = torch.distributed.get_world_size()
        else:
            self.rank = 0  # non dist training
        self.device = torch.device("cuda" if opt["gpu_ids"] is not None else "cpu")
        self.is_train = opt["is_train"]
        self.log_dict = OrderedDict()

        self.data_names = []
        self.networks = {}  # name -> nn.Module (DP/DDP-wrapped)
        self.optimizers = {}  # name -> torch.optim optimizer
        self.schedulers = {}  # name -> lr scheduler

    def setup_train(self, train_opt):
        """Build losses, optimizers and schedulers from the train config,
        then put every network into training mode."""
        # define losses
        loss_opt = train_opt["losses"]
        self.losses = self.build_losses(loss_opt)

        # build optmizers
        optimizer_opts = train_opt["optimizers"]
        self.optimizers = self.build_optimizers(optimizer_opts)

        # set schedulers
        scheduler_opts = train_opt["schedulers"]
        self.schedulers = self.build_schedulers(scheduler_opts)

        # set to training state
        self.set_network_state(self.networks.keys(), "train")

    # --- hooks overridden by concrete models ---
    def feed_data(self, data):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        pass

    def get_current_losses(self):
        pass

    # NOTE(review): shadowed by the print_network(self, net) defined below.
    def print_network(self):
        pass

    # NOTE(review): shadowed by the save(self, iter_label) defined below.
    def save(self, label):
        pass

    def load(self):
        pass

    def build_network(self, net_opt):
        """Create a network from config, move it to device, and optionally
        load pretrained weights declared under net_opt["pretrain"]."""
        net = build_network(net_opt)
        if isinstance(net, nn.Module):
            net = self.model_to_device(net)
            if net_opt.get("pretrain"):
                pretrain = net_opt.pop("pretrain")
                self.load_network(net, pretrain["path"], pretrain["strict_load"])
            self.print_network(net)
        return net

    def build_losses(self, loss_opt):
        """Instantiate every configured loss with weight > 0; weights are
        popped from the config into self.loss_weights."""
        losses = {}
        defined_loss_names = list(loss_opt.keys())
        assert set(defined_loss_names).issubset(set(self.loss_names))

        for name in defined_loss_names:
            loss_conf = loss_opt.get(name)
            if loss_conf["weight"] > 0:
                self.loss_weights[name] = loss_conf.pop("weight")
                losses[name] = build_loss(loss_conf).to(self.device)
        return losses

    def build_optimizers(self, optim_opts):
        """Create one optimizer per configured network; the "default" entry
        supplies options for networks whose entry is None."""
        optimizers = {}
        if "default" in optim_opts.keys():
            default_optim = optim_opts.pop("default")

        defined_optimizer_names = list(optim_opts.keys())
        assert set(defined_optimizer_names).issubset(self.networks.keys())

        for name in defined_optimizer_names:
            optim_opt = optim_opts[name]
            if optim_opt is None:
                optim_opt = default_optim.copy()

            # Only optimize parameters that require gradients.
            params = []
            for v in self.networks[name].parameters():
                if v.requires_grad:
                    params.append(v)

            optim_type = optim_opt.pop("type")
            optimizer = getattr(torch.optim, optim_type)(params=params, **optim_opt)
            optimizers[name] = optimizer
        return optimizers

    def build_schedulers(self, scheduler_opts):
        """Set up scheduler."""
        schedulers = {}
        if "default" in scheduler_opts.keys():
            default_opt = scheduler_opts.pop("default")

        for name in self.optimizers.keys():
            scheduler_opt = scheduler_opts[name]
            if scheduler_opt is None:
                scheduler_opt = default_opt.copy()
            schedulers[name] = build_scheduler(self.optimizers[name], scheduler_opt)
        return schedulers

    def model_to_device(self, net):
        """Model to device. It also warps models with DistributedDataParallel
        or DataParallel.

        Args:
            net (nn.Module)
        """
        net = net.to(self.device)
        if self.opt["dist"]:
            net = DistributedDataParallel(net, device_ids=[torch.cuda.current_device()])
        else:
            net = DataParallel(net)
        return net

    def print_network(self, net):
        """Log the network's class name, structure and parameter count
        (rank 0 only)."""
        # Generator
        s, n = self.get_network_description(net)
        if isinstance(net, nn.DataParallel) or isinstance(net, DistributedDataParallel):
            net_struc_str = "{} - {}".format(
                net.__class__.__name__, net.module.__class__.__name__
            )
        else:
            net_struc_str = "{}".format(net.__class__.__name__)
        if self.rank <= 0:
            logger.info(
                "Network G structure: {}, with parameters: {:,d}".format(
                    net_struc_str, n
                )
            )
            logger.info(s)

    def set_optimizer(self, names, operation):
        # operation is an optimizer method name, e.g. "zero_grad" or "step".
        for name in names:
            getattr(self.optimizers[name], operation)()

    def set_requires_grad(self, names, requires_grad):
        # Freeze/unfreeze all parameters of the listed networks.
        for name in names:
            if isinstance(self.networks[name], nn.Module):
                for v in self.networks[name].parameters():
                    v.requires_grad = requires_grad

    def set_network_state(self, names, state):
        # state is a module method name, e.g. "train" or "eval".
        for name in names:
            if isinstance(self.networks[name], nn.Module):
                getattr(self.networks[name], state)()
nn.utils.clip_grad_norm_(self.networks[name].parameters(), max_norm=norm) def _set_lr(self, lr_groups_l): """set learning rate for warmup, lr_groups_l: list for lr_groups. each for a optimizer""" for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): for param_group, lr in zip(optimizer.param_groups, lr_groups): param_group["lr"] = lr def _get_init_lr(self): # get the initial lr, which is set by the scheduler init_lr_groups_l = [] for optimizer in self.optimizers: init_lr_groups_l.append([v["initial_lr"] for v in optimizer.param_groups]) return init_lr_groups_l def update_learning_rate(self, cur_iter, warmup_iter=-1): for _, scheduler in self.schedulers.items(): scheduler.step() #### set up warm up learning rate if cur_iter < warmup_iter: # get initial lr for each group init_lr_g_l = self._get_init_lr() # modify warming-up learning rates warm_up_lr_l = [] for init_lr_g in init_lr_g_l: warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) # set learning rate self._set_lr(warm_up_lr_l) def get_current_learning_rate(self): # return self.schedulers[0].get_lr()[0] return list(self.optimizers.values())[0].param_groups[0]["lr"] def get_network_description(self, network): """Get the string and total parameters of the network""" if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module s = str(network) n = sum(map(lambda x: x.numel(), network.parameters())) return s, n def save_network(self, network, network_label, iter_label): save_filename = "{}_{}.pth".format(iter_label, network_label) save_path = os.path.join(self.opt["path"]["models"], save_filename) if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module state_dict = network.state_dict() for key, param in state_dict.items(): state_dict[key] = param.cpu() torch.save(state_dict, save_path) def save(self, iter_label): for name in self.optimizers.keys(): 
    def load_network(self, network, load_path, strict=True):
        """Load weights from ``load_path``, stripping any 'module.' prefix
        left by (Distributed)DataParallel-saved checkpoints."""
        if load_path is not None:
            if isinstance(network, nn.DataParallel) or isinstance(
                network, DistributedDataParallel
            ):
                network = network.module
            load_net = torch.load(load_path)
            load_net_clean = OrderedDict()  # remove unnecessary 'module.'
            for k, v in load_net.items():
                if k.startswith("module."):
                    load_net_clean[k[7:]] = v
                else:
                    load_net_clean[k] = v
            network.load_state_dict(load_net_clean, strict=strict)

    def save_training_state(self, epoch, iter_step):
        """Saves training state during training, which will be used for resuming"""
        state = {"epoch": epoch, "iter": iter_step, "schedulers": {}, "optimizers": {}}
        for k, s in self.schedulers.items():
            state["schedulers"][k] = s.state_dict()
        for k, o in self.optimizers.items():
            state["optimizers"][k] = o.state_dict()
        save_filename = "{}.state".format(iter_step)
        save_path = os.path.join(self.opt["path"]["training_state"], save_filename)
        torch.save(state, save_path)

    def resume_training(self, resume_state):
        """Resume the optimizers and schedulers for training"""
        resume_optimizers = resume_state["optimizers"]
        resume_schedulers = resume_state["schedulers"]
        assert len(resume_optimizers) == len(
            self.optimizers
        ), "Wrong lengths of optimizers"
        assert len(resume_schedulers) == len(
            self.schedulers
        ), "Wrong lengths of schedulers"
        # Restore state dicts by name.
        for name, o in resume_optimizers.items():
            self.optimizers[name].load_state_dict(o)
        for name, s in resume_schedulers.items():
            self.schedulers[name].load_state_dict(s)
""" with torch.no_grad(): if self.opt["dist"]: keys = [] losses = [] for name, value in loss_dict.items(): keys.append(name) losses.append(value) losses = torch.stack(losses, 0) torch.distributed.reduce(losses, dst=0) if self.rank == 0: losses /= self.world_size loss_dict = {key: loss for key, loss in zip(keys, losses)} log_dict = OrderedDict() for name, value in loss_dict.items(): log_dict[name] = value.mean().item() return log_dict def get_current_log(self): return self.log_dict ================================================ FILE: codes/config/DSGANSR/models/deg_sr_model.py ================================================ import logging from collections import OrderedDict import random import torch import torch.nn as nn from utils.registry import MODEL_REGISTRY from .base_model import BaseModel logger = logging.getLogger("base") class Quant(torch.autograd.Function): @staticmethod def forward(ctx, input): output = torch.clamp(input, 0, 1) output = (output * 255.).round() / 255. return output @staticmethod def backward(ctx, grad_output): return grad_output class Quantization(nn.Module): def __init__(self): super(Quantization, self).__init__() def forward(self, input): return Quant.apply(input) @MODEL_REGISTRY.register() class DegSRModel(BaseModel): def __init__(self, opt): super().__init__(opt) if opt["dist"]: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training self.data_names = ["src", "tgt"] self.network_names = ["netSR", "netDeg", "netD1", "netD2"] self.networks = {} self.loss_names = [ "lr_adv", "lr_percep", "sr_adv", "sr_pix_trans", "sr_pix_sr", "sr_percep", "color" ] self.loss_weights = {} self.losses = {} self.optimizers = {} # define networks and load pretrained models nets_opt = opt["networks"] defined_network_names = list(nets_opt.keys()) assert set(defined_network_names).issubset(set(self.network_names)) for name in defined_network_names: setattr(self, name, self.build_network(nets_opt[name])) self.networks[name] = 
getattr(self, name) if self.is_train: train_opt = opt["train"] # setup loss, optimizers, schedulers self.setup_train(train_opt) self.max_grad_norm = train_opt["max_grad_norm"] self.quant = Quantization() self.D_ratio = train_opt["D_ratio"] self.optim_deg = train_opt["optim_deg"] self.optim_sr = train_opt["optim_sr"] ## buffer self.fake_lr_buffer = ShuffleBuffer(train_opt["buffer_size"]) self.fake_hr_buffer = ShuffleBuffer(train_opt["buffer_size"]) def feed_data(self, data): self.syn_hr = data["tgt"].to(self.device) self.real_lr = data["src"].to(self.device) def encoder_forward(self): self.fake_real_lr = self.netDeg(self.syn_hr) def decoder_forward(self): if not self.optim_deg: self.fake_real_lr = self.netDeg(self.syn_hr) self.fake_real_lr_quant = self.quant(self.fake_real_lr) self.syn_sr = self.netSR(self.fake_real_lr_quant.detach()) def optimize_trans_models(self, loss_dict, step): self.set_requires_grad(["netSR"], False) self.encoder_forward() loss_G = 0 if self.losses.get("lr_adv"): self.set_requires_grad(["netD1"], False) g1_adv_loss = self.calculate_gan_loss_G( self.netD1, self.losses["lr_adv"], self.real_lr, self.fake_real_lr ) loss_dict["g1_adv"] = g1_adv_loss.item() loss_G += self.loss_weights["lr_adv"] * g1_adv_loss if self.losses.get("lr_percep"): lr_percep, lr_style = self.losses["lr_percep"]( self.real_lr, self.fake_real_lr ) loss_dict["lr_percep"] = lr_percep.item() if lr_style is not None: loss_dict["lr_style"] = lr_style.item() loss_G += self.loss_weights["lr_percep"] * lr_style loss_G += self.loss_weights["lr_percep"] * lr_percep if self.losses.get("color"): color = self.losses["color"]( self.fake_real_lr, self.syn_hr ) loss_dict["color"] = color.item() loss_G += self.loss_weights["color"] * color self.set_optimizer(names=["netDeg"], operation="zero_grad") loss_G.backward() self.clip_grad_norm(["netDeg"], self.max_grad_norm) self.set_optimizer(names=["netDeg"], operation="step") self.update_learning_rate(["netDeg"], step) ## update D if 
self.losses.get("lr_adv"): if step % self.D_ratio == 0: self.set_requires_grad(["netD1"], True) loss_d1 = self.calculate_gan_loss_D( self.netD1, self.losses["lr_adv"], self.real_lr, self.fake_lr_buffer.choose(self.fake_real_lr) ) loss_dict["d1_adv"] = loss_d1.item() loss_D = self.loss_weights["lr_adv"] * loss_d1 self.optimizers["netD1"].zero_grad() loss_D.backward() self.clip_grad_norm(["netD1"], self.max_grad_norm) self.optimizers["netD1"].step() self.update_learning_rate(["netD1"], step) return loss_dict def optimize_sr_models(self, loss_dict, step): self.set_requires_grad(["netSR"], True) self.set_requires_grad(["netDeg"], False) self.decoder_forward() loss_G = 0 if self.losses.get("sr_adv"): self.set_requires_grad(["netD2"], False) sr_adv_loss = self.calculate_gan_loss_G( self.netD2, self.losses["sr_adv"], self.syn_hr, self.syn_sr ) loss_dict["sr_adv"] = sr_adv_loss.item() loss_G += self.loss_weights["sr_adv"] * sr_adv_loss if self.losses.get("sr_percep"): sr_percep, sr_style = self.losses["sr_percep"]( self.syn_hr, self.syn_sr ) loss_dict["sr_percep"] = sr_percep.item() if sr_style is not None: loss_dict["sr_style"] = sr_style.item() loss_G += self.loss_weights["sr_percep"] * sr_style loss_G += self.loss_weights["sr_percep"] * sr_percep if self.losses.get("sr_pix_sr"): sr_pix = self.losses["sr_pix_sr"](self.syn_hr, self.syn_sr) loss_dict["sr_pix_sr"] = sr_pix.item() loss_G += self.loss_weights["sr_pix_sr"] * sr_pix self.set_optimizer(names=["netSR"], operation="zero_grad") loss_G.backward() self.clip_grad_norm(["netSR"], self.max_grad_norm) self.set_optimizer(names=["netSR"], operation="step") self.update_learning_rate(["netSR"], step) ## update D2 if self.losses.get("sr_adv"): if step % self.D_ratio == 0: self.set_requires_grad(["netD2"], True) loss_d2 = self.calculate_gan_loss_D( self.netD2, self.losses["sr_adv"], self.syn_hr, self.fake_hr_buffer.choose(self.syn_sr) ) loss_dict["d2_adv"] = loss_d2.item() loss_D = self.loss_weights["sr_adv"] * loss_d2 
self.optimizers["netD2"].zero_grad() loss_D.backward() self.clip_grad_norm(["netD2"], self.max_grad_norm) self.optimizers["netD2"].step() self.update_learning_rate(["netD2"], step) return loss_dict def optimize_parameters(self, step): loss_dict = OrderedDict() # optimize trans if self.optim_deg: loss_dict = self.optimize_trans_models(loss_dict, step) # optimize SR if self.optim_sr: loss_dict = self.optimize_sr_models(loss_dict, step) self.log_dict = loss_dict def calculate_gan_loss_D(self, netD, criterion, real, fake): d_pred_fake = netD(fake.detach()) d_pred_real = netD(real) loss_real = criterion(d_pred_real, True, is_disc=True) loss_fake = criterion(d_pred_fake, False, is_disc=True) return (loss_real + loss_fake) / 2 def calculate_gan_loss_G(self, netD, criterion, real, fake): d_pred_fake = netD(fake) loss_real = criterion(d_pred_fake, True, is_disc=False) return loss_real def test(self, test_data): self.src = test_data["src"].to(self.device) if test_data.get("tgt") is not None: tgt = test_data["tgt"].to(self.device) b, c, h, w = tgt.shape crop_h = h // 8 * 8; crop_w = w // 8 * 8 self.tgt = tgt[:, :, :crop_h, :crop_w] self.set_network_state(["netDeg", "netSR"], "eval") with torch.no_grad(): self.fake_tgt = self.netSR(self.src) if test_data.get("tgt") is not None: self.fake_lr = self.netDeg(self.tgt) self.set_network_state(["netDeg", "netSR"], "train") def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict["lr"] = self.src.detach()[0].float().cpu() out_dict["sr"] = self.fake_tgt.detach()[0].float().cpu() if hasattr(self, "fake_lr"): out_dict["fake_lr"] = self.fake_lr.detach()[0].float().cpu() return out_dict class ShuffleBuffer(): """Random choose previous generated images or ones produced by the latest generators. :param buffer_size: the size of image buffer :type buffer_size: int """ def __init__(self, buffer_size): """Initialize the ImagePool class. 
:param buffer_size: the size of image buffer :type buffer_size: int """ self.buffer_size = buffer_size self.num_imgs = 0 self.images = [] def choose(self, images, prob=0.5): """Return an image from the pool. :param images: the latest generated images from the generator :type images: list :param prob: probability (0~1) of return previous images from buffer :type prob: float :return: Return images from the buffer :rtype: list """ if self.buffer_size == 0: return images return_images = [] for image in images: image = torch.unsqueeze(image.data, 0) if self.num_imgs < self.buffer_size: self.images.append(image) return_images.append(image) self.num_imgs += 1 else: p = random.uniform(0, 1) if p < prob: idx = random.randint(0, self.buffer_size - 1) stored_image = self.images[idx].clone() self.images[idx] = image return_images.append(stored_image) else: return_images.append(image) return_images = torch.cat(return_images, 0) return return_images ================================================ FILE: codes/config/DSGANSR/options/test/2017Track1.yml ================================================ #### general settings name: 2017Track1 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [0] metrics: [psnr, ssim, lpips] datasets: test1: name: 2017Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test2: # name: 2018Track2 # mode: PairedDataset # data_type: lmdb # dataroot_src: /mnt/hdd/lzx/SRDatasets/NTIRE2018/valid_mild.lmdb # dataroot_tgt: /mnt/hdd/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test3: # name: 2018Track3 # mode: PairedDataset # data_type: lmdb # dataroot_src: /mnt/hdd/lzx/SRDatasets/NTIRE2018/valid_difficult.lmdb # dataroot_tgt: /mnt/hdd/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test4: # name: 2018Track4 # mode: PairedDataset # data_type: lmdb # dataroot_src: /mnt/hdd/lzx/SRDatasets/NTIRE2018/valid_wild.lmdb # dataroot_tgt: 
/mnt/hdd/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test5: # name: 2020Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /mnt/hdd/lzx/SRDatasets/NTIRE2020/track1_valid_input.lmdb # dataroot_tgt: /mnt/hdd/lzx/SRDatasets/NTIRE2020/track1_valid_gt.lmdb #### network structures networks: Encoder: which_network: Translator setting: nb: 16 nf: 64 scale: 0.25 zero_tail: true pretrain: path: log/2017Track1_deg/models/200000_Encoder.pth strict_load: true Decoder: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: # path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt path: log/2017Track1/models/latest_Decoder.pth strict_load: true ================================================ FILE: codes/config/DSGANSR/options/test/2018Track2.yml ================================================ #### general settings name: 2018Track2 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [0] metrics: [best_psnr, best_ssim, lpips] datasets: test1: name: 2017Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb networks: Encoder: which_network: Translator setting: nb: 16 nf: 64 scale: 0.25 zero_tail: true pretrain: path: log/2018Track2_deg/models/200000_Encoder.pth strict_load: true Decoder: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: # path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt path: log/2018Track2/models/latest_Decoder.pth strict_load: true ================================================ FILE: codes/config/DSGANSR/options/test/2018Track4.yml ================================================ #### general settings name: 2018Track4 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [6] metrics: [best_psnr, best_ssim, lpips] datasets: test1: name: 2017Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid.lmdb dataroot_tgt: 
/home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb networks: Encoder: which_network: Translator setting: nb: 16 nf: 64 scale: 0.25 zero_tail: true pretrain: path: log/2018Track4_deg/models/200000_Encoder.pth strict_load: true Decoder: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: # path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt path: log/2018Track4/models/latest_Decoder.pth strict_load: true ================================================ FILE: codes/config/DSGANSR/options/test/2020Track1.yml ================================================ #### general settings name: 2020Track1 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [0] metrics: [psnr, ssim, lpips] datasets: test1: name: 2020Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb networks: Encoder: which_network: Translator setting: nb: 16 nf: 64 scale: 0.25 zero_tail: true pretrain: path: log/2020Track1_deg/models/70000_Encoder.pth strict_load: true Decoder: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: # path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt path: log/2020Track1/models/170000_Decoder.pth strict_load: true ================================================ FILE: codes/config/DSGANSR/options/train/deg/2017Track2.yml ================================================ #### general settings name: 2017Track2_deg use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [5] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/train_LR/x4_half.lmdb use_shuffle: true workers_per_gpu: 4 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2017Track2_mini mode: PairedDataset data_type: lmdb 
color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: Translator setting: nb: 16 nf: 64 scale: 0.25 zero_tail: true pretrain: path: log/2017Track1/models/195000_netDeg.pth strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: true optim_sr: false losses: color: type: ColorLoss ksize: 5 stride: 4 recursion: 1 loss_type: mse weight: 1.0 lr_percep: type: PerceptualLoss layer_weights: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true range_norm: false perceptual_weight: 1.0 style_weight: 0 criterion: l1 weight: !!float 0.01 lr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.005 sr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.0 sr_pix_sr: type: L1Loss weight: 1.0 optimizers: default: type: Adam lr: !!float 2e-4 netDeg: ~ netSR: ~ netD1: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/DSGANSR/options/train/deg/2018Track2.yml ================================================ #### general settings name: 2018Track2_deg use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [0] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color:
RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/x4_half.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2018Track2 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: Translator setting: nb: 16 nf: 64 scale: 0.25 zero_tail: true pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: true optim_sr: false losses: color: type: ColorLoss ksize: 5 stride: 4 recursion: 1 loss_type: mse weight: 1.0 lr_percep: type: PerceptualLoss layer_weights: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true range_norm: false perceptual_weight: 1.0 style_weight: 0 criterion: l1 weight: !!float 0.01 lr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.005 sr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.0 sr_pix_sr: type: L1Loss weight: 1.0 optimizers: default: type: Adam lr: !!float 2e-4 netDeg: ~ netSR: ~ netD1: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE:
codes/config/DSGANSR/options/train/deg/2018Track4.yml ================================================ #### general settings name: 2018Track4_deg use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [4] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/x4.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2018Track2 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: Translator setting: nb: 16 nf: 64 scale: 0.25 zero_tail: true pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: # path: ~ path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: true optim_sr: false losses: color: type: ColorLoss ksize: 5 stride: 4 recursion: 1 loss_type: mse weight: 1.0 lr_percep: type: PerceptualLoss layer_weights: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true range_norm: false perceptual_weight: 1.0 style_weight: 0 criterion: l1 weight: !!float 0.01 lr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.005 sr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.0 sr_pix_sr: type: L1Loss weight: 1.0 optimizers: default: type: Adam lr: !!float 2e-4 netDeg: ~ netSR: ~ netD1: ~ niter:
200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/DSGANSR/options/train/deg/2020Track1.yml ================================================ #### general settings name: 2020Track1_deg use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [2] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/train_source.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2020Track1 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: Translator setting: nb: 16 nf: 64 scale: 0.25 zero_tail: true pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: true optim_sr: false losses: color: type: ColorLoss ksize: 5 stride: 4 recursion: 1 loss_type: mse weight: 1.0 lr_percep: type: PerceptualLoss layer_weights: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true range_norm: false perceptual_weight: 1.0 style_weight: 0 criterion: l1 weight: !!float 0.01 lr_adv: type: GANLoss gan_type: lsgan
real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.005 sr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.0 sr_pix_sr: type: L1Loss weight: 1.0 optimizers: default: type: Adam lr: !!float 2e-4 netDeg: ~ netSR: ~ netD1: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/DSGANSR/options/train/sr/2017Track2.yml ================================================ #### general settings name: 2017Track2 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [0] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/train_LR/x4_half.lmdb use_shuffle: true workers_per_gpu: 4 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2017Track2_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: Translator setting: nb: 16 nf: 64 scale: 0.25 zero_tail: true pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: false optim_sr: true niter: 200000 warmup_iter: -1 # no warm up manual_seed: 0 val_freq: !!float 5e3 losses: sr_pix_sr: type: L1Loss weight: 1.0 optimizers: netSR: type: Adam lr: !!float
2e-4 schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/DSGANSR/options/train/sr/2018Track2.yml ================================================ #### general settings name: 2018Track2 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [0] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/x4_half.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2018Track2 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: Translator setting: nb: 16 nf: 64 scale: 0.25 zero_tail: true pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: false optim_sr: true niter: 200000 warmup_iter: -1 # no warm up manual_seed: 0 val_freq: !!float 5e3 losses: sr_pix_sr: type: L1Loss weight: 1.0 optimizers: netSR: type: Adam lr: !!float 2e-4 schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/DSGANSR/options/train/sr/2018Track4.yml ================================================ #### general settings 
name: 2018Track4 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [6] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/x4.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2018Track2 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: Translator setting: nb: 16 nf: 64 scale: 0.25 zero_tail: true pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: false optim_sr: true niter: 200000 warmup_iter: -1 # no warm up manual_seed: 0 val_freq: !!float 5e3 losses: sr_pix_sr: type: L1Loss weight: 1.0 optimizers: netSR: type: Adam lr: !!float 2e-4 schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/DSGANSR/options/train/sr/2020Track1.yml ================================================ #### general settings name: 2020Track1 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [7] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/train_source.lmdb use_shuffle: 
true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2020Track1 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: Encoder: which_network: Translator setting: nb: 16 nf: 64 scale: 0.25 zero_tail: true pretrain: path: ~ strict_load: true Decoder: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: false optim_sr: true niter: 200000 warmup_iter: -1 # no warm up manual_seed: 0 val_freq: !!float 5e3 losses: sr_pix_sr: type: L1Loss weight: 1.0 optimizers: netSR: type: Adam lr: !!float 2e-4 schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/DSGANSR/test.py ================================================ import argparse import logging import os.path import sys import time from collections import OrderedDict, defaultdict import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from metrics import IQA from models import create_model from utils import bgr2ycbcr, imresize def parse_args(): parser = argparse.ArgumentParser(description="Train keypoints network") # general parser.add_argument( "--opt", help="experiment configure file name", required=True, type=str ) parser.add_argument( "--root_path", help="experiment configure file name", default="../../../", type=str, ) # distributed 
training parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str) parser.add_argument( "--world-size", default=1, type=int, help="number of nodes for distributed training", ) parser.add_argument( "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training", ) parser.add_argument( "--rank", default=0, type=int, help="node rank for distributed training" ) args = parser.parse_args() return args def main(): args = parse_args() opt = option.parse(args.opt, args.root_path, is_train=False) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) if args.dist_url == "env://" and args.world_size == -1: args.world_size = int(os.environ["WORLD_SIZE"]) ngpus_per_node = torch.cuda.device_count() args.world_size = ngpus_per_node * args.world_size opt["dist"] = args.world_size > 1 util.mkdirs( (path for key, path in opt["path"].items() if not key == "experiments_root") ) os.system("rm ./result") os.symlink(os.path.join(opt["path"]["results_root"], ".."), "./result") if opt["dist"]: mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args)) else: main_worker(0, 1, opt, args) def main_worker(gpu, ngpus_per_node, opt, args): if opt["dist"]: if args.dist_url == "env://" and args.rank == -1: rank = int(os.environ["RANK"]) rank = args.rank * ngpus_per_node + gpu print( f"Init process group: dist_url: {args.dist_url}, world_size: {args.world_size}, rank: {rank}" ) dist.init_process_group( backend="nccl", init_method=args.dist_url, world_size=args.world_size, rank=rank, ) torch.cuda.set_device(gpu) else: rank = 0 torch.backends.cudnn.benchmark = True util.setup_logger( "base", opt["path"]["log"], "test_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO, screen=True, tofile=True, ) measure = IQA(metrics=opt["metrics"], cuda=True) logger = logging.getLogger("base") logger.info(option.dict2str(opt)) # Create test dataset and dataloader test_datasets = [] 
test_loaders = [] for phase, dataset_opt in sorted(opt["datasets"].items()): test_set = create_dataset(dataset_opt) test_loader = create_dataloader(test_set, dataset_opt, opt["dist"]) if rank == 0: logger.info( "Number of test images in [{:s}]: {:d}".format( dataset_opt["name"], len(test_set) ) ) test_datasets.append(test_set) test_loaders.append(test_loader) # load pretrained model by default model = create_model(opt) for test_dataset, test_loader in zip(test_datasets, test_loaders): test_set_name = test_dataset.opt["name"] dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name) if rank == 0: logger.info("\nTesting [{:s}]...".format(test_set_name)) util.mkdir(dataset_dir) validate( model, test_dataset, test_loader, opt, measure, dataset_dir, test_set_name, logger, ) def validate( model, dataset, dist_loader, opt, measure, dataset_dir, test_set_name, logger ): test_results = {} test_results_y = {} for metric in opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() test_results_y[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 indices = list(range(rank, len(dataset), world_size)) for ( idx, test_data, ) in enumerate(dist_loader): idx = indices[idx] img_path = test_data["src_path"][0] img_name = img_path.split("/")[-1].split(".")[0] model.test(test_data) visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals["sr"]) # uint8 suffix = opt["suffix"] if suffix: save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png") else: save_img_path = os.path.join(dataset_dir, img_name + ".png") util.save_img(sr_img, save_img_path) message = "img:{:15s}; ".format(img_name) crop_border = opt["crop_border"] if opt["crop_border"] else opt["scale"] if crop_border == 0: cropped_sr_img = sr_img else: cropped_sr_img = sr_img[ crop_border:-crop_border, crop_border:-crop_border, : ] if "tgt" in test_data.keys(): gt_img = 
util.tensor2img(test_data["tgt"][0].double().cpu()) if crop_border == 0: cropped_gt_img = gt_img else: cropped_gt_img = gt_img[ crop_border:-crop_border, crop_border:-crop_border, : ] else: cropped_gt_img = None message += "Scores - " scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v message += "{}: {:.6f}; ".format(k, v) if sr_img.shape[2] == 3: # RGB image sr_img_y = bgr2ycbcr(sr_img, only_y=True) if crop_border == 0: cropped_sr_img_y = sr_img_y * 255 else: cropped_sr_img_y = ( sr_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255 ) if gt_img is not None: gt_img_y = bgr2ycbcr(gt_img, only_y=True) if crop_border == 0: cropped_gt_img_y = gt_img_y * 255 else: cropped_gt_img_y = ( gt_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255 ) else: gt_img_y = None message += "Y Scores - " scores = measure( res=cropped_sr_img_y, ref=cropped_gt_img_y, metrics=opt["metrics"] ) for k, v in scores.items(): test_results_y[k][idx] = v message += "{}: {:.6f}; ".format(k, v) logger.info(message) if opt["dist"]: for k, v in test_results.items(): dist.reduce(v, dst=0) dist.barrier() for k, v in test_results_y.items(): dist.reduce(v, dst=0) dist.barrier() # log avg_results = {} message = "Average Results for {}\n".format(test_set_name) if rank == 0: for k, v in test_results.items(): avg_results[k] = sum(v) / len(v) message += "{}: {:.6f}; ".format(k, avg_results[k]) logger.info(message) avg_results_y = {} message = "Average Results on Y channel for {}\n".format(test_set_name) if rank == 0: for k, v in test_results_y.items(): avg_results[k] = sum(v) / len(v) message += "{}: {:.6f}; ".format(k, avg_results[k]) logger.info(message) if __name__ == "__main__": main() ================================================ FILE: codes/config/DSGANSR/train.py ================================================ import argparse import logging import math import os import random import sys import 
time from collections import defaultdict import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp from tensorboardX import SummaryWriter from tqdm import tqdm sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from metrics import IQA from models import create_model def parse_args(): parser = argparse.ArgumentParser(description="Train keypoints network") # general parser.add_argument( "--opt", help="experiment configure file name", required=True, type=str ) parser.add_argument( "--root_path", help="experiment configure file name", default="../../../", type=str, ) # distributed training parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str) parser.add_argument( "--world-size", default=1, type=int, help="number of nodes for distributed training", ) parser.add_argument( "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training", ) parser.add_argument( "--rank", default=0, type=int, help="node rank for distributed training" ) args = parser.parse_args() return args def setup_dataloaer(opt, logger): if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: rank = 0 world_size = 1 for phase, dataset_opt in opt["datasets"].items(): if phase == "train": train_set = create_dataset(dataset_opt) train_loader = create_dataloader(train_set, dataset_opt, opt["dist"]) total_iters = opt["train"]["niter"] total_epochs = total_iters // (len(train_loader) - 1) + 1 if rank == 0: logger.info( "Number of train images: {:,d}, iters: {:,d}".format( len(train_set), len(train_loader) ) ) logger.info( "Total epochs needed: {:d} for iters {:,d}".format( total_epochs, opt["train"]["niter"] ) ) elif phase == "val": val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt["dist"]) if rank == 0: logger.info( "Number of val images in [{:s}]: 
{:d}".format( dataset_opt["name"], len(val_set) ) ) else: raise NotImplementedError("Phase [{:s}] is not recognized.".format(phase)) assert train_loader is not None assert val_loader is not None return train_set, train_loader, val_set, val_loader, total_iters, total_epochs def main(): args = parse_args() opt = option.parse(args.opt, args.root_path, is_train=True) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) if args.dist_url == "env://" and args.world_size == -1: args.world_size = int(os.environ["WORLD_SIZE"]) ngpus_per_node = torch.cuda.device_count() args.world_size = ngpus_per_node * args.world_size opt["dist"] = args.world_size > 1 if opt["train"].get("resume_state", None) is None: util.mkdir_and_rename( opt["path"]["experiments_root"] ) # rename experiment folder if exists util.mkdirs( (path for key, path in opt["path"].items() if not key == "experiments_root") ) os.system("rm ./log") os.symlink(os.path.join(opt["path"]["experiments_root"], ".."), "./log") if opt["dist"]: mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args)) else: main_worker(0, 1, opt, args) def main_worker(gpu, ngpus_per_node, opt, args): if opt["dist"]: if args.dist_url == "env://" and args.rank == -1: rank = int(os.environ["RANK"]) rank = args.rank * ngpus_per_node + gpu print( f"Init process group: dist_url: \ {args.dist_url}, world_size: {args.world_size}, rank: {rank}" ) dist.init_process_group( backend="nccl", init_method=args.dist_url, world_size=args.world_size, rank=rank, ) torch.cuda.set_device(gpu) else: rank = 0 seed = opt["train"]["manual_seed"] if seed is None: util.set_random_seed(rank) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True # setup tensorboard and val logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger = SummaryWriter(log_dir="log/{}/tb_logger/".format(opt["name"])) util.setup_logger( "val", opt["path"]["log"], "val_" + 
opt["name"], level=logging.INFO, screen=True, tofile=True, ) measure = IQA(metrics=opt["metrics"], cuda=True) # config loggers. Before it, the log will not work util.setup_logger( "base", opt["path"]["log"], "train_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO if rank == 0 else logging.ERROR, screen=True, tofile=True, ) logger = logging.getLogger("base") if rank == 0: logger.info(option.dict2str(opt)) # create dataset ( train_set, train_loader, val_set, val_loader, total_iters, total_epochs, ) = setup_dataloaer(opt, logger) # create model model = create_model(opt) # loading resume state if exists if opt["train"].get("resume_state", None): # distributed resuming: all load into default GPU device_id = gpu resume_state = torch.load( opt["train"]["resume_state"], map_location=lambda storage, loc: storage.cuda(device_id), ) logger.info( "Resuming training from epoch: {}, iter: {}.".format( resume_state["epoch"], resume_state["iter"] ) ) start_epoch = resume_state["epoch"] current_step = resume_state["iter"] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 logger.info( "Start training from epoch: {:d}, iter: {:d}".format(start_epoch, current_step) ) data_time, iter_time = time.time(), time.time() avg_data_time = avg_iter_time = 0 count = 0 for epoch in range(start_epoch, total_epochs + 1): for _, train_data in enumerate(train_loader): current_step += 1 count += 1 if current_step > total_iters: break data_time = time.time() - data_time avg_data_time = (avg_data_time * (count - 1) + data_time) / count model.feed_data(train_data) model.optimize_parameters(current_step) model.update_learning_rate( current_step, warmup_iter=opt["train"]["warmup_iter"] ) iter_time = time.time() - iter_time avg_iter_time = (avg_iter_time * (count - 1) + iter_time) / count # log if current_step % opt["logger"]["print_freq"] == 0: logs = model.get_current_log() message = ( f" " ) message += f'[time (data): 
{avg_iter_time:.3f} ({avg_data_time:.3f})] ' for k, v in logs.items(): message += "{:s}: {:.4e}; ".format(k, v) # tensorboard logger if opt["use_tb_logger"] and "debug" not in opt["name"]: if rank == 0: tb_logger.add_scalar(k, v, current_step) logger.info(message) # validation if current_step % opt["train"]["val_freq"] == 0: avg_results = validate( model, val_set, val_loader, opt, measure, epoch, current_step ) # tensorboard logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: for k, v in avg_results.items(): tb_logger.add_scalar(k, v, current_step) # save models and training states if current_step % opt["logger"]["save_checkpoint_freq"] == 0: if rank == 0: logger.info("Saving models and training states.") model.save(current_step) model.save_training_state(epoch, current_step) data_time = time.time() iter_time = time.time() if rank == 0: logger.info("Saving the final model.") model.save("latest") logger.info("End of training.") if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger.close() def validate(model, dataset, dist_loader, opt, measure, epoch, current_step): test_results = {} for metric in opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 if rank == 0: pbar = tqdm(total=len(dataset), leave=False, dynamic_ncols=True) indices = list(range(rank, len(dataset), world_size)) for ( idx, val_data, ) in enumerate(dist_loader): idx = indices[idx] LR_img = val_data["src"] lr_img = util.tensor2img(LR_img) # save LR image for reference model.test(val_data) visuals = model.get_current_visuals() # Save images for reference img_name = val_data["src_path"][0].split("/")[-1].split(".")[0] img_dir = os.path.join(opt["path"]["val_images"], img_name) util.mkdir(img_dir) save_lr_path = os.path.join(img_dir, "{:s}_LR.png".format(img_name)) util.save_img(lr_img, save_lr_path) sr_img = util.tensor2img(visuals["sr"]) # 
uint8 save_img_path = os.path.join( img_dir, "{:s}_{:d}.png".format(img_name, current_step) ) util.save_img(sr_img, save_img_path) if "fake_lr" in visuals.keys(): fake_lr_img = util.tensor2img(visuals["fake_lr"]) save_img_path = os.path.join( img_dir, f"fake_lr_{current_step:d}.png" ) util.save_img(fake_lr_img, save_img_path) # calculate scores crop_size = opt["scale"] cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] if "tgt" in val_data.keys(): gt_img = util.tensor2img(val_data["tgt"]) cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] else: cropped_gt_img = gt_img = None scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v if rank == 0: for _ in range(world_size): pbar.update(1) if rank == 0: pbar.close() # log avg_results = {} message = " 0`, the perceptual loss will be calculated and the loss will multiplied by the weight. Default: 1.0. style_weight (float): If `style_weight > 0`, the style loss will be calculated and the loss will multiplied by the weight. Default: 0. criterion (str): Criterion used for perceptual loss. Default: 'l1'. """ def __init__( self, layer_weights, vgg_type="vgg19", use_input_norm=True, range_norm=False, perceptual_weight=1.0, style_weight=0.0, criterion="l1", ): super(PerceptualLoss, self).__init__() self.perceptual_weight = perceptual_weight self.style_weight = style_weight self.layer_weights = layer_weights self.vgg = VGGFeatureExtractor( layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, use_input_norm=use_input_norm, range_norm=range_norm, ) self.criterion_type = criterion if self.criterion_type == "l1": self.criterion = torch.nn.L1Loss() elif self.criterion_type == "l2": self.criterion = torch.nn.L2loss() elif self.criterion_type == "fro": self.criterion = None else: raise NotImplementedError(f"{criterion} criterion has not been supported.") def forward(self, x, gt): """Forward function. 
Args: x (Tensor): Input tensor with shape (n, c, h, w). gt (Tensor): Ground-truth tensor with shape (n, c, h, w). Returns: Tensor: Forward results. """ # extract vgg features x_features = self.vgg(x) gt_features = self.vgg(gt.detach()) # calculate perceptual loss if self.perceptual_weight > 0: percep_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": percep_loss += ( torch.norm(x_features[k] - gt_features[k], p="fro") * self.layer_weights[k] ) else: percep_loss += ( self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] ) percep_loss *= self.perceptual_weight else: percep_loss = None # calculate style loss if self.style_weight > 0: style_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": style_loss += ( torch.norm( self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p="fro", ) * self.layer_weights[k] ) else: style_loss += ( self.criterion( self._gram_mat(x_features[k]), self._gram_mat(gt_features[k]), ) * self.layer_weights[k] ) style_loss *= self.style_weight else: style_loss = None return percep_loss, style_loss def _gram_mat(self, x): """Calculate Gram matrix. Args: x (torch.Tensor): Tensor with shape of (n, c, h, w). Returns: torch.Tensor: Gram matrix. 
""" n, c, h, w = x.size() features = x.view(n, c, w * h) features_t = features.transpose(1, 2) gram = features.bmm(features_t) / (c * h * w) return gram @LOSS_REGISTRY.register() class CharbonnierLoss(nn.Module): """Charbonnier Loss (L1)""" def __init__(self, eps=1e-6): super(CharbonnierLoss, self).__init__() self.eps = eps def forward(self, x, y): diff = x - y loss = torch.mean(torch.sqrt(diff * diff + self.eps)) return loss class GradientPenaltyLoss(nn.Module): def __init__(self, device=torch.device("cpu")): super(GradientPenaltyLoss, self).__init__() self.register_buffer("grad_outputs", torch.Tensor()) self.grad_outputs = self.grad_outputs.to(device) def get_grad_outputs(self, input): if self.grad_outputs.size() != input.size(): self.grad_outputs.resize_(input.size()).fill_(1.0) return self.grad_outputs def forward(self, interp, interp_crit): grad_outputs = self.get_grad_outputs(interp_crit) grad_interp = torch.autograd.grad( outputs=interp_crit, inputs=interp, grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True, )[0] grad_interp = grad_interp.view(grad_interp.size(0), -1) grad_interp_norm = grad_interp.norm(2, dim=1) loss = ((grad_interp_norm - 1) ** 2).mean() return loss ================================================ FILE: codes/config/EDSR/archs/lr_scheduler.py ================================================ import math from collections import Counter, defaultdict import torch from torch.optim.lr_scheduler import _LRScheduler from utils.registry import LR_SCHEDULER_REGISTRY @LR_SCHEDULER_REGISTRY.register() class LinearDecayLR(_LRScheduler): def __init__( self, optimizer, decay_prop, total_steps, last_epoch=-1, ): self.decay_prop = decay_prop self.total_steps = total_steps super().__init__(optimizer, last_epoch) def get_lr(self): return [ group["initial_lr"] * (1 - (self.last_epoch + 1) * self.decay_prop / self.total_steps) for group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class 
MultiStepRestartLR(_LRScheduler): def __init__( self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, clear_state=False, last_epoch=-1, ): self.milestones = Counter(milestones) self.gamma = gamma self.clear_state = clear_state self.restarts = restarts if restarts else [0] self.restart_weights = weights if weights else [1] assert len(self.restarts) == len( self.restart_weights ), "restarts and their weights do not match." super().__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch in self.restarts: if self.clear_state: self.optimizer.state = defaultdict(dict) weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [ group["initial_lr"] * weight for group in self.optimizer.param_groups ] if self.last_epoch not in self.milestones: return [group["lr"] for group in self.optimizer.param_groups] return [ group["lr"] * self.gamma ** self.milestones[self.last_epoch] for group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class CosineAnnealingRestartLR(_LRScheduler): def __init__( self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1 ): self.T_period = T_period self.T_max = self.T_period[0] # current T period self.eta_min = eta_min self.restarts = restarts if restarts else [0] self.restart_weights = weights if weights else [1] self.last_restart = 0 assert len(self.restarts) == len( self.restart_weights ), "restarts and their weights do not match." 
super().__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch == 0: return self.base_lrs elif self.last_epoch in self.restarts: self.last_restart = self.last_epoch self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [ group["initial_lr"] * weight for group in self.optimizer.param_groups ] elif (self.last_epoch - self.last_restart - 1 - self.T_max) % ( 2 * self.T_max ) == 0: return [ group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) ] return [ (1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / ( 1 + math.cos( math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max ) ) * (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups ] ================================================ FILE: codes/config/EDSR/archs/module_util.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init def initialize_weights(net_l, scale=1): if not isinstance(net_l, list): net_l = [net_l] for net in net_l: for m in net.modules(): if isinstance(m, nn.Conv2d): init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale # for residual block if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.Linear): init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): init.constant_(m.weight, 1) init.constant_(m.bias.data, 0.0) def make_layer(block, n_layers): layers = [] for _ in range(n_layers): layers.append(block()) return nn.Sequential(*layers) class ResidualBlock_noBN(nn.Module): """Residual block w/o BN ---Conv-ReLU-Conv-+- |________________| """ def __init__(self, nf=64): super(ResidualBlock_noBN, self).__init__() 
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) # initialization initialize_weights([self.conv1, self.conv2], 0.1) def forward(self, x): identity = x out = F.relu(self.conv1(x), inplace=True) out = self.conv2(out) return identity + out def flow_warp(x, flow, interp_mode="bilinear", padding_mode="zeros"): """Warp an image or feature map with optical flow Args: x (Tensor): size (N, C, H, W) flow (Tensor): size (N, H, W, 2), normal value interp_mode (str): 'nearest' or 'bilinear' padding_mode (str): 'zeros' or 'border' or 'reflection' Returns: Tensor: warped image or feature map """ assert x.size()[-2:] == flow.size()[1:3] B, C, H, W = x.size() # mesh grid grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 grid.requires_grad = False grid = grid.type_as(x) vgrid = grid + flow # scale grid to [-1,1] vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) return output ================================================ FILE: codes/config/EDSR/archs/rcan.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from utils.registry import ARCH_REGISTRY def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias ) class MeanShift(nn.Conv2d): def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): super(MeanShift, self).__init__(3, 3, kernel_size=1) std = torch.Tensor(rgb_std) self.weight.data = torch.eye(3).view(3, 3, 1, 1) self.weight.data.div_(std.view(3, 1, 1, 1)) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 
self.bias.data.div_(std) self.requires_grad = False class BasicBlock(nn.Sequential): def __init__( self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True), ): m = [ nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), stride=stride, bias=bias, ) ] if bn: m.append(nn.BatchNorm2d(out_channels)) if act is not None: m.append(act) super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): def __init__( self, conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(ResBlock, self).__init__() m = [] for i in range(2): m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: m.append(nn.BatchNorm2d(n_feat)) if i == 0: m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale def forward(self, x): res = self.body(x).mul(self.res_scale) res += x return res class Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
for _ in range(int(math.log(scale, 2))): m.append(conv(n_feat, 4 * n_feat, 3, bias)) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 3: m.append(conv(n_feat, 9 * n_feat, 3, bias)) m.append(nn.PixelShuffle(3)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) else: raise NotImplementedError super(Upsampler, self).__init__(*m) def make_model(args, parent=False): return RCAN(args) ## Channel Attention (CA) Layer class CALayer(nn.Module): def __init__(self, channel, reduction=16): super(CALayer, self).__init__() # global average pooling: feature --> point self.avg_pool = nn.AdaptiveAvgPool2d(1) # feature channel downscale and upscale --> channel weight self.conv_du = nn.Sequential( nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), nn.ReLU(inplace=True), nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), nn.Sigmoid(), ) def forward(self, x): y = self.avg_pool(x) y = self.conv_du(y) return x * y ## Residual Channel Attention Block (RCAB) class RCAB(nn.Module): def __init__( self, conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(RCAB, self).__init__() modules_body = [] for i in range(2): modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: modules_body.append(nn.BatchNorm2d(n_feat)) if i == 0: modules_body.append(act) modules_body.append(CALayer(n_feat, reduction)) self.body = nn.Sequential(*modules_body) self.res_scale = res_scale def forward(self, x): res = self.body(x) # res = self.body(x).mul(self.res_scale) res += x return res ## Residual Group (RG) class ResidualGroup(nn.Module): def __init__( self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks ): super(ResidualGroup, self).__init__() modules_body = [] modules_body = [ RCAB( conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ) for _ in range(n_resblocks) ] 
modules_body.append(conv(n_feat, n_feat, kernel_size)) self.body = nn.Sequential(*modules_body) def forward(self, x): res = self.body(x) res += x return res ## Residual Channel Attention Network (RCAN) @ARCH_REGISTRY.register() class RCAN(nn.Module): def __init__(self, ng, nb, nf, reduction=16, upscale=4, conv=default_conv): super(RCAN, self).__init__() n_resgroups = ng n_resblocks = nb n_feats = nf kernel_size = 3 reduction = reduction scale = upscale act = nn.ReLU(True) # RGB mean for DIV2K rgb_mean = (0.4488, 0.4371, 0.4040) rgb_std = (1.0, 1.0, 1.0) self.sub_mean = MeanShift(1.0, rgb_mean, rgb_std, -1) # define head module modules_head = [conv(3, n_feats, kernel_size)] # define body module modules_body = [ ResidualGroup( conv, n_feats, kernel_size, reduction, act=act, res_scale=1.0, n_resblocks=nb, ) for _ in range(ng) ] modules_body.append(conv(n_feats, n_feats, kernel_size)) # define tail module modules_tail = [ Upsampler(conv, scale, n_feats, act=False), conv(n_feats, 3, kernel_size), ] self.add_mean = MeanShift(1.0, rgb_mean, rgb_std, 1) self.head = nn.Sequential(*modules_head) self.body = nn.Sequential(*modules_body) self.tail = nn.Sequential(*modules_tail) def forward(self, x): x = self.sub_mean(x) x = self.head(x) res = self.body(x) res += x x = self.tail(res) x = self.add_mean(x) return x def load_state_dict(self, state_dict, strict=False): own_state = self.state_dict() for name, param in state_dict.items(): if name in own_state: if isinstance(param, nn.Parameter): param = param.data try: own_state[name].copy_(param) except Exception: if name.find("tail") >= 0: print("Replace pre-trained upsampler to new one...") else: raise RuntimeError( "While copying the parameter named {}, " "whose dimensions in the model are {} and " "whose dimensions in the checkpoint are {}.".format( name, own_state[name].size(), param.size() ) ) elif strict: if name.find("tail") == -1: raise KeyError('unexpected key "{}" in state_dict'.format(name)) if strict: missing = 
set(own_state.keys()) - set(state_dict.keys()) if len(missing) > 0: raise KeyError('missing keys in state_dict: "{}"'.format(missing)) ================================================ FILE: codes/config/EDSR/archs/rrdb.py ================================================ import functools from utils.registry import ARCH_REGISTRY from .module_util import * class ResidualDenseBlock_5C(nn.Module): def __init__(self, nf=64, gc=32, bias=True): super(ResidualDenseBlock_5C, self).__init__() # gc: growth channel, i.e. intermediate channels self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) # initialization initialize_weights( [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1 ) def forward(self, x): x1 = self.lrelu(self.conv1(x)) x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) return x5 * 0.2 + x class RRDB(nn.Module): """Residual in Residual Dense Block""" def __init__(self, nf, gc=32): super(RRDB, self).__init__() self.RDB1 = ResidualDenseBlock_5C(nf, gc) self.RDB2 = ResidualDenseBlock_5C(nf, gc) self.RDB3 = ResidualDenseBlock_5C(nf, gc) def forward(self, x): out = self.RDB1(x) out = self.RDB2(out) out = self.RDB3(out) return out * 0.2 + x @ARCH_REGISTRY.register() class RRDBNet(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4): super(RRDBNet, self).__init__() self.upscale = upscale RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.RRDB_trunk = make_layer(RRDB_block_f, nb) self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 
1, bias=True) #### upsampling self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) if upscale == 4: self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): fea = self.conv_first(x) trunk = self.trunk_conv(self.RRDB_trunk(fea)) fea = fea + trunk if self.upscale == 2 or self.upscale == 3: fea = self.lrelu( self.upconv1( F.interpolate(fea, scale_factor=self.upscale, mode="nearest") ) ) if self.upscale == 4: fea = self.lrelu( self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest")) ) fea = self.lrelu( self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest")) ) out = self.conv_last(self.lrelu(self.HRconv(fea))) return out ================================================ FILE: codes/config/EDSR/archs/srresnet.py ================================================ import functools from utils.registry import ARCH_REGISTRY from .module_util import * @ARCH_REGISTRY.register() class MSRResNet(nn.Module): """modified SRResNet""" def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4): super(MSRResNet, self).__init__() self.upscale = upscale self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) basic_block = functools.partial(ResidualBlock_noBN, nf=nf) self.recon_trunk = make_layer(basic_block, nb) # upsampling if self.upscale == 2: self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(2) elif self.upscale == 3: self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(3) elif self.upscale == 4: self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(2) self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) # activation function self.lrelu = 
nn.LeakyReLU(negative_slope=0.1, inplace=True) # initialization initialize_weights( [self.conv_first, self.upconv1, self.HRconv, self.conv_last], 0.1 ) if self.upscale == 4: initialize_weights(self.upconv2, 0.1) def forward(self, x): fea = self.lrelu(self.conv_first(x)) out = self.recon_trunk(fea) if self.upscale == 4: out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) elif self.upscale == 3 or self.upscale == 2: out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) out = self.conv_last(self.lrelu(self.HRconv(out))) base = F.interpolate( x, scale_factor=self.upscale, mode="bilinear", align_corners=False ) out += base return out ================================================ FILE: codes/config/EDSR/archs/translator.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from utils.registry import ARCH_REGISTRY def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias ) class BasicBlock(nn.Sequential): def __init__( self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True), ): m = [ nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), stride=stride, bias=bias, ) ] if bn: m.append(nn.BatchNorm2d(out_channels)) if act is not None: m.append(act) super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): def __init__( self, conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(ResBlock, self).__init__() m = [] for i in range(2): m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: m.append(nn.BatchNorm2d(n_feat)) if i == 0: m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale def forward(self, x): res = self.body(x).mul(self.res_scale) res += x return res class 
Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? for _ in range(int(math.log(scale, 2))): m.append(conv(n_feat, 4 * n_feat, 3, bias)) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 3: m.append(conv(n_feat, 9 * n_feat, 3, bias)) m.append(nn.PixelShuffle(3)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 1: m.append(nn.Identity()) else: raise NotImplementedError super(Upsampler, self).__init__(*m) @ARCH_REGISTRY.register() class Translator(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, scale=4, conv=default_conv): super().__init__() self.scale = scale # define head module if scale >= 1: m_head = [conv(in_nc, nf, 3)] else: s = int(1 / scale) m_head = [nn.Conv2d(in_nc, nf, kernel_size=2 * s + 1, stride=s, padding=s)] # define body module m_body = [ ResBlock(conv, nf, 3, act=nn.ReLU(True), res_scale=1) for _ in range(nb) ] m_body.append(conv(nf, nf, 3)) # define tail module m_tail = [ Upsampler(conv, scale, nf, act=False) if scale > 1 else nn.Identity(), conv(nf, out_nc, 3), ] self.head = nn.Sequential(*m_head) self.body = nn.Sequential(*m_body) self.tail = nn.Sequential(*m_tail) def forward(self, x): x = self.head(x) f = self.body(x) x = f + x x = self.tail(x) return x ================================================ FILE: codes/config/EDSR/archs/vgg.py ================================================ import os from collections import OrderedDict import torch from torch import nn as nn from torchvision.models import vgg as vgg from utils.registry import ARCH_REGISTRY VGG_PRETRAIN_PATH = "checkpoints/pretrained_models/vgg19-dcbb9e9d.pth" NAMES = { "vgg11": [ "conv1_1", "relu1_1", "pool1", "conv2_1", "relu2_1", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", 
"pool5", ], "vgg13": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "pool5", ], "vgg16": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "pool5", ], "vgg19": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "conv3_4", "relu3_4", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "conv4_4", "relu4_4", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "conv5_4", "relu5_4", "pool5", ], } def insert_bn(names): """Insert bn layer after each conv. Args: names (list): The list of layer names. Returns: list: The list of layer names with bn layers. """ names_bn = [] for name in names: names_bn.append(name) if "conv" in name: position = name.replace("conv", "") names_bn.append("bn" + position) return names_bn @ARCH_REGISTRY.register() class VGGFeatureExtractor(nn.Module): """VGG network for feature extraction. In this implementation, we allow users to choose whether use normalization in the input feature and the type of vgg network. Note that the pretrained path must fit the vgg type. Args: layer_name_list (list[str]): Forward function returns the corresponding features according to the layer_name_list. Example: {'relu1_1', 'relu2_1', 'relu3_1'}. vgg_type (str): Set the type of vgg network. Default: 'vgg19'. use_input_norm (bool): If True, normalize the input image. Importantly, the input feature must in the range [0, 1]. 
Default: True. range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. Default: False. requires_grad (bool): If true, the parameters of VGG network will be optimized. Default: False. remove_pooling (bool): If true, the max pooling operations in VGG net will be removed. Default: False. pooling_stride (int): The stride of max pooling operation. Default: 2. """ def __init__( self, layer_name_list, vgg_type="vgg19", use_input_norm=True, range_norm=False, requires_grad=False, remove_pooling=False, pooling_stride=2, ): super(VGGFeatureExtractor, self).__init__() self.layer_name_list = layer_name_list self.use_input_norm = use_input_norm self.range_norm = range_norm self.names = NAMES[vgg_type.replace("_bn", "")] if "bn" in vgg_type: self.names = insert_bn(self.names) # only borrow layers that will be used to avoid unused params max_idx = 0 for v in layer_name_list: idx = self.names.index(v) if idx > max_idx: max_idx = idx if os.path.exists(VGG_PRETRAIN_PATH): vgg_net = getattr(vgg, vgg_type)(pretrained=False) state_dict = torch.load( VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage ) vgg_net.load_state_dict(state_dict) else: vgg_net = getattr(vgg, vgg_type)(pretrained=True) features = vgg_net.features[: max_idx + 1] modified_net = OrderedDict() for k, v in zip(self.names, features): if "pool" in k: # if remove_pooling is true, pooling operation will be removed if remove_pooling: continue else: # in some cases, we may want to change the default stride modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) else: modified_net[k] = v self.vgg_net = nn.Sequential(modified_net) if not requires_grad: self.vgg_net.eval() for param in self.parameters(): param.requires_grad = False else: self.vgg_net.train() for param in self.parameters(): param.requires_grad = True if self.use_input_norm: # the mean is for image with range [0, 1] self.register_buffer( "mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) ) # the std is for image with 
range [0, 1] self.register_buffer( "std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) ) def forward(self, x): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). Returns: Tensor: Forward results. """ if self.range_norm: x = (x + 1) / 2 if self.use_input_norm: x = (x - self.mean) / self.std output = {} for key, layer in self.vgg_net._modules.items(): x = layer(x) if key in self.layer_name_list: output[key] = x.clone() return output ================================================ FILE: codes/config/EDSR/count_flops.py ================================================ import argparse import sys import torch from torchsummaryX import summary sys.path.append("../../") import utils.option as option from models import create_model parser = argparse.ArgumentParser() parser.add_argument( "--opt", type=str, default="options/setting1/test/test_setting1_x4.yml", help="Path to option YMAL file of Predictor.", ) args = parser.parse_args() opt = option.parse(args.opt, root_path=".", is_train=True) opt = option.dict_to_nonedict(opt) model = create_model(opt) test_tensor = torch.randn(1, 3, 270, 180).cuda() for name, net in model.networks.items(): summary(net.cuda(), x=test_tensor) print("Above are results for net {}".format(name)) input() ================================================ FILE: codes/config/EDSR/inference.py ================================================ import argparse import logging import math import os import os.path as osp import random import sys import cv2 from collections import defaultdict from glob import glob from tqdm import tqdm import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp from tensorboardX import SummaryWriter sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from data.data_sampler import DistIterSampler from metrics import IQA from models import create_model #### options parser = 
argparse.ArgumentParser() parser.add_argument( "-opt", type=str, default="options/test/2020Track2.yml", help="Path to options YMAL file.", ) parser.add_argument("-input_dir", type=str, default="../../../data_samples/LR") parser.add_argument("-output_dir", type=str, default="../../../data_samples/BSRGAN") args = parser.parse_args() opt = option.parse(args.opt, is_train=False) opt = option.dict_to_nonedict(opt) model = create_model(opt) if not osp.exists(args.output_dir): os.makedirs(args.output_dir) test_files = glob(osp.join(args.input_dir, "*")) for inx, path in tqdm(enumerate(test_files)): name = path.split("/")[-1].split(".")[0] img = cv2.imread(path)[:, :, [2, 1, 0]] img = img.transpose(2, 0, 1)[None] / 255 img_t = torch.as_tensor(np.ascontiguousarray(img)).float() model.test({"src": img_t}) outdict = model.get_current_visuals() sr = outdict["sr"] sr_im = util.tensor2img(sr) save_path = osp.join(args.output_dir, "{}_x{}.png".format(name, opt["scale"])) cv2.imwrite(save_path, sr_im) ================================================ FILE: codes/config/EDSR/models/__init__.py ================================================ import importlib import logging import os import os.path as osp from utils.registry import MODEL_REGISTRY logger = logging.getLogger("base") model_folder = osp.dirname(__file__) model_names = [ osp.splitext(osp.basename(v))[0] for v in os.listdir(model_folder) if v.endswith("_model.py") ] _model_modules = [ importlib.import_module(f"models.{file_name}") for file_name in model_names ] def create_model(opt, **kwarg): model = opt["model"] m = MODEL_REGISTRY.get(model)(opt, **kwarg) logger.info("Model [{:s}] is created.".format(m.__class__.__name__)) return m ================================================ FILE: codes/config/EDSR/models/base_model.py ================================================ import logging import os from collections import OrderedDict import torch import torch.nn as nn from torch.nn.parallel import DataParallel, 
DistributedDataParallel from archs import build_loss, build_network, build_scheduler from utils.registry import MODEL_REGISTRY logger = logging.getLogger("base") @MODEL_REGISTRY.register() class BaseModel: def __init__(self, opt): self.opt = opt if opt["dist"]: self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() else: self.rank = 0 # non dist training self.device = torch.device("cuda" if opt["gpu_ids"] is not None else "cpu") self.is_train = opt["is_train"] self.log_dict = OrderedDict() self.data_names = [] self.networks = {} self.optimizers = {} self.schedulers = {} def setup_train(self, train_opt): # define losses loss_opt = train_opt["losses"] self.losses = self.build_losses(loss_opt) # build optmizers optimizer_opts = train_opt["optimizers"] self.optimizers = self.build_optimizers(optimizer_opts) # set schedulers scheduler_opts = train_opt["schedulers"] self.schedulers = self.build_schedulers(scheduler_opts) # set to training state self.set_network_state(self.networks.keys(), "train") def feed_data(self, data): pass def optimize_parameters(self): pass def get_current_visuals(self): pass def get_current_losses(self): pass def print_network(self): pass def save(self, label): pass def load(self): pass def build_network(self, net_opt): net = build_network(net_opt) if isinstance(net, nn.Module): net = self.model_to_device(net) if net_opt.get("pretrain"): pretrain = net_opt.pop("pretrain") self.load_network(net, pretrain["path"], pretrain["strict_load"]) self.print_network(net) return net def build_losses(self, loss_opt): losses = {} defined_loss_names = list(loss_opt.keys()) assert set(defined_loss_names).issubset(set(self.loss_names)) for name in defined_loss_names: loss_conf = loss_opt.get(name) if loss_conf["weight"] > 0: self.loss_weights[name] = loss_conf.pop("weight") losses[name] = build_loss(loss_conf).to(self.device) return losses def build_optimizers(self, optim_opts): optimizers = {} if "default" in 
optim_opts.keys(): default_optim = optim_opts.pop("default") defined_optimizer_names = list(optim_opts.keys()) assert set(defined_optimizer_names).issubset(self.networks.keys()) for name in defined_optimizer_names: optim_opt = optim_opts[name] if optim_opt is None: optim_opt = default_optim.copy() params = [] for v in self.networks[name].parameters(): if v.requires_grad: params.append(v) optim_type = optim_opt.pop("type") optimizer = getattr(torch.optim, optim_type)(params=params, **optim_opt) optimizers[name] = optimizer return optimizers def build_schedulers(self, scheduler_opts): """Set up scheduler.""" schedulers = {} if "default" in scheduler_opts.keys(): default_opt = scheduler_opts.pop("default") for name in self.optimizers.keys(): scheduler_opt = scheduler_opts[name] if scheduler_opt is None: scheduler_opt = default_opt.copy() schedulers[name] = build_scheduler(self.optimizers[name], scheduler_opt) return schedulers def model_to_device(self, net): """Model to device. It also warps models with DistributedDataParallel or DataParallel. 
Args: net (nn.Module) """ net = net.to(self.device) if self.opt["dist"]: net = DistributedDataParallel(net, device_ids=[torch.cuda.current_device()]) else: net = DataParallel(net) return net def print_network(self, net): # Generator s, n = self.get_network_description(net) if isinstance(net, nn.DataParallel) or isinstance(net, DistributedDataParallel): net_struc_str = "{} - {}".format( net.__class__.__name__, net.module.__class__.__name__ ) else: net_struc_str = "{}".format(net.__class__.__name__) if self.rank <= 0: logger.info( "Network G structure: {}, with parameters: {:,d}".format( net_struc_str, n ) ) logger.info(s) def set_optimizer(self, names, operation): for name in names: getattr(self.optimizers[name], operation)() def set_requires_grad(self, names, requires_grad): for name in names: if isinstance(self.networks[name], nn.Module): for v in self.networks[name].parameters(): v.requires_grad = requires_grad def set_network_state(self, names, state): for name in names: if isinstance(self.networks[name], nn.Module): getattr(self.networks[name], state)() def clip_grad_norm(self, names, norm): for name in names: nn.utils.clip_grad_norm_(self.networks[name].parameters(), max_norm=norm) def _set_lr(self, lr_groups_l): """set learning rate for warmup, lr_groups_l: list for lr_groups. 
each for a optimizer""" for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): for param_group, lr in zip(optimizer.param_groups, lr_groups): param_group["lr"] = lr def _get_init_lr(self): # get the initial lr, which is set by the scheduler init_lr_groups_l = [] for optimizer in self.optimizers: init_lr_groups_l.append([v["initial_lr"] for v in optimizer.param_groups]) return init_lr_groups_l def update_learning_rate(self, cur_iter, warmup_iter=-1): for _, scheduler in self.schedulers.items(): scheduler.step() #### set up warm up learning rate if cur_iter < warmup_iter: # get initial lr for each group init_lr_g_l = self._get_init_lr() # modify warming-up learning rates warm_up_lr_l = [] for init_lr_g in init_lr_g_l: warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) # set learning rate self._set_lr(warm_up_lr_l) def get_current_learning_rate(self): # return self.schedulers[0].get_lr()[0] return list(self.optimizers.values())[0].param_groups[0]["lr"] def get_network_description(self, network): """Get the string and total parameters of the network""" if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module s = str(network) n = sum(map(lambda x: x.numel(), network.parameters())) return s, n def save_network(self, network, network_label, iter_label): save_filename = "{}_{}.pth".format(iter_label, network_label) save_path = os.path.join(self.opt["path"]["models"], save_filename) if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module state_dict = network.state_dict() for key, param in state_dict.items(): state_dict[key] = param.cpu() torch.save(state_dict, save_path) def save(self, iter_label): for name in self.optimizers.keys(): self.save_network(self.networks[name], name, iter_label) def load_network(self, network, load_path, strict=True): if load_path is not None: if isinstance(network, nn.DataParallel) or isinstance( 
network, DistributedDataParallel ): network = network.module load_net = torch.load(load_path) load_net_clean = OrderedDict() # remove unnecessary 'module.' for k, v in load_net.items(): if k.startswith("module."): load_net_clean[k[7:]] = v else: load_net_clean[k] = v network.load_state_dict(load_net_clean, strict=strict) def save_training_state(self, epoch, iter_step): """Saves training state during training, which will be used for resuming""" state = {"epoch": epoch, "iter": iter_step, "schedulers": {}, "optimizers": {}} for k, s in self.schedulers.items(): state["schedulers"][k] = s.state_dict() for k, o in self.optimizers.items(): state["optimizers"][k] = o.state_dict() save_filename = "{}.state".format(iter_step) save_path = os.path.join(self.opt["path"]["training_state"], save_filename) torch.save(state, save_path) def resume_training(self, resume_state): """Resume the optimizers and schedulers for training""" resume_optimizers = resume_state["optimizers"] resume_schedulers = resume_state["schedulers"] assert len(resume_optimizers) == len( self.optimizers ), "Wrong lengths of optimizers" assert len(resume_schedulers) == len( self.schedulers ), "Wrong lengths of schedulers" for name, o in resume_optimizers.items(): self.optimizers[name].load_state_dict(o) for name, s in resume_schedulers.items(): self.schedulers[name].load_state_dict(s) def reduce_loss_dict(self, loss_dict): """reduce loss dict. In distributed training, it averages the losses among different GPUs . Args: loss_dict (OrderedDict): Loss dict. 
""" with torch.no_grad(): if self.opt["dist"]: keys = [] losses = [] for name, value in loss_dict.items(): keys.append(name) losses.append(value) losses = torch.stack(losses, 0) torch.distributed.reduce(losses, dst=0) if self.rank == 0: losses /= self.world_size loss_dict = {key: loss for key, loss in zip(keys, losses)} log_dict = OrderedDict() for name, value in loss_dict.items(): log_dict[name] = value.mean().item() return log_dict def get_current_log(self): return self.log_dict ================================================ FILE: codes/config/EDSR/models/sr_model.py ================================================ import logging from collections import OrderedDict import torch import torch.nn as nn from utils.registry import MODEL_REGISTRY from .base_model import BaseModel logger = logging.getLogger("base") @MODEL_REGISTRY.register() class SRModel(BaseModel): def __init__(self, opt): super().__init__(opt) self.data_names = ["lr", "hr"] self.network_names = ["netSR"] self.networks = {} self.loss_names = ["sr_adv", "sr_pix", "sr_percep"] self.loss_weights = {} self.losses = {} self.optimizers = {} # define networks and load pretrained models nets_opt = opt["networks"] defined_network_names = list(nets_opt.keys()) assert set(defined_network_names).issubset(set(self.network_names)) for name in defined_network_names: setattr(self, name, self.build_network(nets_opt[name])) self.networks[name] = getattr(self, name) if self.is_train: # setup loss, optimizers, schedulers self.setup_train(opt["train"]) def feed_data(self, data): self.lr = data["src"].to(self.device) self.hr = data["tgt"].to(self.device) def forward(self): self.sr = self.netSR(self.lr) def optimize_parameters(self, step): self.forward() loss_dict = OrderedDict() l_sr = 0 sr_pix = self.losses["sr_pix"](self.hr, self.sr) loss_dict["sr_pix"] = sr_pix l_sr += self.loss_weights["sr_pix"] * sr_pix if self.losses.get("sr_adv"): self.set_requires_grad(["netD"], False) sr_adv_g = self.calculate_rgan_loss_G( 
self.netD, self.losses["sr_adv"], self.hr, self.sr ) loss_dict["sr_adv_g"] = sr_adv_g l_sr += self.loss_weights["sr_adv"] * sr_adv_g if self.losses.get("sr_percep"): sr_percep, sr_style = self.losses["sr_percep"](self.hr, self.sr) loss_dict["sr_percep"] = sr_percep if sr_style is not None: loss_dict["sr_style"] = sr_style l_sr += self.loss_weights["sr_percep"] * sr_style l_sr += self.loss_weights["sr_percep"] * sr_percep self.set_optimizer(names=["netSR"], operation="zero_grad") l_sr.backward() self.set_optimizer(names=["netSR"], operation="step") if self.losses.get("sr_adv"): self.set_requires_grad(["netD"], True) sr_adv_d = self.calculate_rgan_loss_D( self.netD, self.losses["sr_adv"], self.hr, self.sr ) loss_dict["sr_adv_d"] = sr_adv_d self.optimizers["netD"].zero_grad() sr_adv_d.backward() self.optimizers["netD"].step() self.log_dict = self.reduce_loss_dict(loss_dict) def calculate_rgan_loss_D(self, netD, criterion, real, fake): d_pred_fake = netD(fake.detach()) d_pred_real = netD(real) loss_real = criterion( d_pred_real - d_pred_fake.detach().mean(), True, is_disc=False ) loss_fake = criterion( d_pred_fake - d_pred_real.detach().mean(), False, is_disc=False ) loss = (loss_real + loss_fake) / 2 return loss def calculate_rgan_loss_G(self, netD, criterion, real, fake): d_pred_fake = netD(fake) d_pred_real = netD(real).detach() loss_real = criterion(d_pred_real - d_pred_fake.mean(), False, is_disc=False) loss_fake = criterion(d_pred_fake - d_pred_real.mean(), True, is_disc=False) loss = (loss_real + loss_fake) / 2 return loss def test(self, data, crop_size=None): self.real_lr = data["src"].to(self.device) self.netSR.eval() with torch.no_grad(): if crop_size is None: self.fake_real_hr = self.netSR(self.real_lr) else: self.fake_real_hr = self.crop_test(self.real_lr, crop_size) self.netSR.train() def crop_test(self, lr, crop_size): b, c, h, w = lr.shape scale = self.opt["scale"] h_start = list(range(0, h-crop_size, crop_size)) w_start = list(range(0, w-crop_size, 
crop_size)) sr1 = torch.zeros(b, c, int(h*scale), int(w* scale), device=self.device) - 1 for hs in h_start: for ws in w_start: lr_patch = lr[:, :, hs: hs+crop_size, ws: ws+crop_size] sr_patch = self.netSR(lr_patch) sr1[:, :, int(hs*scale):int((hs+crop_size)*scale), int(ws*scale):int((ws+crop_size)*scale) ] = sr_patch h_end = list(range(h, crop_size, -crop_size)) w_end = list(range(w, crop_size, -crop_size)) sr2 = torch.zeros(b, c, int(h*scale), int(w* scale), device=self.device) - 1 for hd in h_end: for wd in w_end: lr_patch = lr[:, :, hd-crop_size:hd, wd-crop_size:wd] sr_patch = self.netSR(lr_patch) sr2[:, :, int((hd-crop_size)*scale):int(hd*scale), int((wd-crop_size)*scale):int(wd*scale) ] = sr_patch mask1 = ( (sr1 == -1).float() * 0 + (sr2 == -1).float() * 1 + ((sr1 > 0) * (sr2 > 0)).float() * 0.5 ) mask2 = ( (sr1 == -1).float() * 1 + (sr2 == -1).float() * 0 + ((sr1 > 0) * (sr2 > 0)).float() * 0.5 ) sr = mask1 * sr1 + mask2 * sr2 return sr def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict["lr"] = self.real_lr.detach()[0].float().cpu() out_dict["sr"] = self.fake_real_hr.detach()[0].float().cpu() return out_dict ================================================ FILE: codes/config/EDSR/options/test/2017Track2_2020Track1.yml ================================================ #### general settings name: Bicubic_2017Track2_2020Track1 use_tb_logger: false model: SRModel scale: 4 gpu_ids: [5] metrics: [psnr, ssim, lpips, niqe, piqe, brisque] datasets: test1: name: 2017Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb test5: name: 2020Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netSR: which_network: EDSR setting: nb: 16 nf: 64 res_scale: 1 upscale: 4 pretrain: path: 
../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true ================================================ FILE: codes/config/EDSR/options/test/2018Track2_2020Track4.yml ================================================ #### general settings name: Bicubic_2018Track2_2018Track4 use_tb_logger: false model: SRModel scale: 4 gpu_ids: [5] metrics: [best_psnr, best_ssim, lpips, niqe, piqe, brisque] datasets: test1: name: 2018Track2 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb test2: name: 2018Track4 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netSR: which_network: EDSR setting: nb: 16 nf: 64 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true ================================================ FILE: codes/config/EDSR/options/test/2020Track2.yml ================================================ #### general settings name: 2020Track2 use_tb_logger: false model: SRModel scale: 4 gpu_ids: [5] metrics: [niqe, piqe, brisque] datasets: test1: name: 2020Track2 mode: SingleDataset data_type: lmdb dataroot: /home/lzx/SRDatasets/NTIRE2020/track2/test.lmdb #### network structures networks: netSR: which_network: EDSR setting: nb: 16 nf: 64 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true ================================================ FILE: codes/config/EDSR/test.py ================================================ import argparse import logging import os.path import sys import time from collections import OrderedDict, defaultdict import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp sys.path.append("../../") import utils as util import utils.option as option from 
data import create_dataloader, create_dataset from metrics import IQA from models import create_model from utils import bgr2ycbcr, imresize def parse_args(): parser = argparse.ArgumentParser(description="Train keypoints network") # general parser.add_argument( "--opt", help="experiment configure file name", required=True, type=str ) parser.add_argument( "--root_path", help="experiment configure file name", default="../../../", type=str, ) # distributed training parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str) parser.add_argument( "--world-size", default=1, type=int, help="number of nodes for distributed training", ) parser.add_argument( "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training", ) parser.add_argument( "--rank", default=0, type=int, help="node rank for distributed training" ) args = parser.parse_args() return args def main(): args = parse_args() opt = option.parse(args.opt, args.root_path, is_train=False) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) if args.dist_url == "env://" and args.world_size == -1: args.world_size = int(os.environ["WORLD_SIZE"]) ngpus_per_node = torch.cuda.device_count() args.world_size = ngpus_per_node * args.world_size opt["dist"] = args.world_size > 1 util.mkdirs( (path for key, path in opt["path"].items() if not key == "experiments_root") ) os.system("rm ./result") os.symlink(os.path.join(opt["path"]["results_root"], ".."), "./result") if opt["dist"]: mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args)) else: main_worker(0, 1, opt, args) def main_worker(gpu, ngpus_per_node, opt, args): if opt["dist"]: if args.dist_url == "env://" and args.rank == -1: rank = int(os.environ["RANK"]) rank = args.rank * ngpus_per_node + gpu print( f"Init process group: dist_url: {args.dist_url}, world_size: {args.world_size}, rank: {rank}" ) dist.init_process_group( backend="nccl", 
init_method=args.dist_url, world_size=args.world_size, rank=rank, ) torch.cuda.set_device(gpu) else: rank = 0 torch.backends.cudnn.benchmark = True util.setup_logger( "base", opt["path"]["log"], "test_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO, screen=True, tofile=True, ) measure = IQA(metrics=opt["metrics"], cuda=True) logger = logging.getLogger("base") logger.info(option.dict2str(opt)) # Create test dataset and dataloader test_datasets = [] test_loaders = [] for phase, dataset_opt in sorted(opt["datasets"].items()): test_set = create_dataset(dataset_opt) test_loader = create_dataloader(test_set, dataset_opt, opt["dist"]) if rank == 0: logger.info( "Number of test images in [{:s}]: {:d}".format( dataset_opt["name"], len(test_set) ) ) test_datasets.append(test_set) test_loaders.append(test_loader) # load pretrained model by default model = create_model(opt) for test_dataset, test_loader in zip(test_datasets, test_loaders): test_set_name = test_dataset.opt["name"] dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name) if rank == 0: logger.info("\nTesting [{:s}]...".format(test_set_name)) util.mkdir(dataset_dir) validate( model, test_dataset, test_loader, opt, measure, dataset_dir, test_set_name, logger, ) def validate( model, dataset, dist_loader, opt, measure, dataset_dir, test_set_name, logger ): test_results = {} test_results_y = {} for metric in opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() test_results_y[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 indices = list(range(rank, len(dataset), world_size)) for ( idx, test_data, ) in enumerate(dist_loader): idx = indices[idx] img_path = test_data["src_path"][0] img_name = img_path.split("/")[-1].split(".")[0] model.test(test_data) visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals["sr"]) # uint8 suffix = opt["suffix"] if suffix: 
save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png") else: save_img_path = os.path.join(dataset_dir, img_name + ".png") util.save_img(sr_img, save_img_path) message = "img:{:15s}; ".format(img_name) crop_border = opt["crop_border"] if opt["crop_border"] else opt["scale"] if crop_border == 0: cropped_sr_img = sr_img else: cropped_sr_img = sr_img[ crop_border:-crop_border, crop_border:-crop_border, : ] if "tgt" in test_data.keys(): gt_img = util.tensor2img(test_data["tgt"][0].double().cpu()) if crop_border == 0: cropped_gt_img = gt_img else: cropped_gt_img = gt_img[ crop_border:-crop_border, crop_border:-crop_border, : ] else: gt_img = None cropped_gt_img = None message += "Scores - " scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v message += "{}: {:.6f}; ".format(k, v) if sr_img.shape[2] == 3: # RGB image sr_img_y = bgr2ycbcr(sr_img, only_y=True) if crop_border == 0: cropped_sr_img_y = sr_img_y * 255 else: cropped_sr_img_y = ( sr_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255 ) if gt_img is not None: gt_img_y = bgr2ycbcr(gt_img, only_y=True) if crop_border == 0: cropped_gt_img_y = gt_img_y * 255 else: cropped_gt_img_y = ( gt_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255 ) else: gt_img_y = None cropped_gt_img_y = None message += "Y Scores - " scores = measure( res=cropped_sr_img_y, ref=cropped_gt_img_y, metrics=opt["metrics"] ) for k, v in scores.items(): test_results_y[k][idx] = v message += "{}: {:.6f}; ".format(k, v) logger.info(message) if opt["dist"]: for k, v in test_results.items(): dist.reduce(v, dst=0) dist.barrier() for k, v in test_results_y.items(): dist.reduce(v, dst=0) dist.barrier() # log avg_results = {} message = "Average Results for {}\n".format(test_set_name) if rank == 0: for k, v in test_results.items(): avg_results[k] = sum(v) / len(v) message += "{}: {:.6f}; ".format(k, avg_results[k]) logger.info(message) 
avg_results_y = {} message = "Average Results on Y channel for {}\n".format(test_set_name) if rank == 0: for k, v in test_results_y.items(): avg_results[k] = sum(v) / len(v) message += "{}: {:.6f}; ".format(k, avg_results[k]) logger.info(message) if __name__ == "__main__": main() ================================================ FILE: codes/config/EDSR/train.py ================================================ import argparse import logging import math import os import random import sys import time from collections import defaultdict import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp from tensorboardX import SummaryWriter from tqdm import tqdm sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from metrics import IQA from models import create_model def parse_args(): parser = argparse.ArgumentParser(description="Train keypoints network") # general parser.add_argument( "--opt", help="experiment configure file name", required=True, type=str ) parser.add_argument( "--root_path", help="experiment configure file name", default="../../../", type=str, ) # distributed training parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str) parser.add_argument( "--world-size", default=1, type=int, help="number of nodes for distributed training", ) parser.add_argument( "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training", ) parser.add_argument( "--rank", default=0, type=int, help="node rank for distributed training" ) args = parser.parse_args() return args def setup_dataloaer(opt, logger): if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: rank = 0 world_size = 1 for phase, dataset_opt in opt["datasets"].items(): if phase == "train": train_set = create_dataset(dataset_opt) train_loader = create_dataloader(train_set, dataset_opt, opt["dist"]) total_iters = 
opt["train"]["niter"] total_epochs = total_iters // (len(train_loader) - 1) + 1 if rank == 0: logger.info( "Number of train images: {:,d}, iters: {:,d}".format( len(train_set), len(train_loader) ) ) logger.info( "Total epochs needed: {:d} for iters {:,d}".format( total_epochs, opt["train"]["niter"] ) ) elif phase == "val": val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt["dist"]) if rank == 0: logger.info( "Number of val images in [{:s}]: {:d}".format( dataset_opt["name"], len(val_set) ) ) else: raise NotImplementedError("Phase [{:s}] is not recognized.".format(phase)) assert train_loader is not None assert val_loader is not None return train_set, train_loader, val_set, val_loader, total_iters, total_epochs def main(): args = parse_args() opt = option.parse(args.opt, args.root_path, is_train=True) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) if args.dist_url == "env://" and args.world_size == -1: args.world_size = int(os.environ["WORLD_SIZE"]) ngpus_per_node = torch.cuda.device_count() args.world_size = ngpus_per_node * args.world_size opt["dist"] = args.world_size > 1 if opt["train"].get("resume_state", None) is None: util.mkdir_and_rename( opt["path"]["experiments_root"] ) # rename experiment folder if exists util.mkdirs( (path for key, path in opt["path"].items() if not key == "experiments_root") ) os.system("rm ./log") os.symlink(os.path.join(opt["path"]["experiments_root"], ".."), "./log") if opt["dist"]: mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args)) else: main_worker(0, 1, opt, args) def main_worker(gpu, ngpus_per_node, opt, args): if opt["dist"]: if args.dist_url == "env://" and args.rank == -1: rank = int(os.environ["RANK"]) rank = args.rank * ngpus_per_node + gpu print( f"Init process group: dist_url: \ {args.dist_url}, world_size: {args.world_size}, rank: {rank}" ) dist.init_process_group( backend="nccl", 
init_method=args.dist_url, world_size=args.world_size, rank=rank, ) torch.cuda.set_device(gpu) else: rank = 0 seed = opt["train"]["manual_seed"] if seed is None: util.set_random_seed(rank) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True # setup tensorboard and val logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger = SummaryWriter(log_dir="log/{}/tb_logger/".format(opt["name"])) util.setup_logger( "val", opt["path"]["log"], "val_" + opt["name"], level=logging.INFO, screen=True, tofile=True, ) measure = IQA(metrics=opt["metrics"], cuda=True) # config loggers. Before it, the log will not work util.setup_logger( "base", opt["path"]["log"], "train_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO if rank == 0 else logging.ERROR, screen=True, tofile=True, ) logger = logging.getLogger("base") if rank == 0: logger.info(option.dict2str(opt)) # create dataset ( train_set, train_loader, val_set, val_loader, total_iters, total_epochs, ) = setup_dataloaer(opt, logger) # create model model = create_model(opt) # loading resume state if exists if opt["train"].get("resume_state", None): # distributed resuming: all load into default GPU device_id = gpu resume_state = torch.load( opt["train"]["resume_state"], map_location=lambda storage, loc: storage.cuda(device_id), ) logger.info( "Resuming training from epoch: {}, iter: {}.".format( resume_state["epoch"], resume_state["iter"] ) ) start_epoch = resume_state["epoch"] current_step = resume_state["iter"] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 logger.info( "Start training from epoch: {:d}, iter: {:d}".format(start_epoch, current_step) ) data_time, iter_time = time.time(), time.time() avg_data_time = avg_iter_time = 0 count = 0 for epoch in range(start_epoch, total_epochs + 1): for _, train_data in enumerate(train_loader): current_step += 1 count += 1 if current_step > total_iters: break 
data_time = time.time() - data_time avg_data_time = (avg_data_time * (count - 1) + data_time) / count model.feed_data(train_data) model.optimize_parameters(current_step) model.update_learning_rate( current_step, warmup_iter=opt["train"]["warmup_iter"] ) iter_time = time.time() - iter_time avg_iter_time = (avg_iter_time * (count - 1) + iter_time) / count # log if current_step % opt["logger"]["print_freq"] == 0: logs = model.get_current_log() message = ( f" " ) message += f'[time (data): {avg_iter_time:.3f} ({avg_data_time:.3f})] ' for k, v in logs.items(): message += "{:s}: {:.4e}; ".format(k, v) # tensorboard logger if opt["use_tb_logger"] and "debug" not in opt["name"]: if rank == 0: tb_logger.add_scalar(k, v, current_step) logger.info(message) # validation if current_step % opt["train"]["val_freq"] == 0: avg_results = validate( model, val_set, val_loader, opt, measure, epoch, current_step ) # tensorboard logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: for k, v in avg_results.items(): tb_logger.add_scalar(k, v, current_step) # save models and training states if current_step % opt["logger"]["save_checkpoint_freq"] == 0: if rank == 0: logger.info("Saving models and training states.") model.save(current_step) model.save_training_state(epoch, current_step) data_time = time.time() iter_time = time.time() if rank == 0: logger.info("Saving the final model.") model.save("latest") logger.info("End of training.") if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger.close() def validate(model, dataset, dist_loader, opt, measure, epoch, current_step): test_results = {} for metric in opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 if rank == 0: pbar = tqdm(total=len(dataset), leave=False, dynamic_ncols=True) indices = list(range(rank, len(dataset), world_size)) for ( idx, val_data, ) in 
enumerate(dist_loader): idx = indices[idx] LR_img = val_data["src"] lr_img = util.tensor2img(LR_img) # save LR image for reference model.test(val_data) visuals = model.get_current_visuals() # Save images for reference img_name = val_data["src_path"][0].split("/")[-1].split(".")[0] img_dir = os.path.join(opt["path"]["val_images"], img_name) util.mkdir(img_dir) save_lr_path = os.path.join(img_dir, "{:s}_LR.png".format(img_name)) util.save_img(lr_img, save_lr_path) sr_img = util.tensor2img(visuals["sr"]) # uint8 save_img_path = os.path.join( img_dir, "{:s}_{:d}.png".format(img_name, current_step) ) util.save_img(sr_img, save_img_path) if "fake_lr" in visuals.keys(): fake_lr_img = util.tensor2img(visuals["fake_lr"]) save_img_path = os.path.join( img_dir, f"fake_lr_{current_step:d}.png" ) util.save_img(fake_lr_img, save_img_path) # calculate scores crop_size = opt["scale"] cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] if "tgt" in val_data.keys(): gt_img = util.tensor2img(val_data["tgt"]) cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] else: cropped_gt_img = gt_img = None scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v if rank == 0: for _ in range(world_size): pbar.update(1) if rank == 0: pbar.close() # log avg_results = {} message = " 0`, the perceptual loss will be calculated and the loss will multiplied by the weight. Default: 1.0. style_weight (float): If `style_weight > 0`, the style loss will be calculated and the loss will multiplied by the weight. Default: 0. criterion (str): Criterion used for perceptual loss. Default: 'l1'. 
""" def __init__( self, layer_weights, vgg_type="vgg19", use_input_norm=True, range_norm=False, perceptual_weight=1.0, style_weight=0.0, criterion="l1", ): super(PerceptualLoss, self).__init__() self.perceptual_weight = perceptual_weight self.style_weight = style_weight self.layer_weights = layer_weights self.vgg = VGGFeatureExtractor( layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, use_input_norm=use_input_norm, range_norm=range_norm, ) self.criterion_type = criterion if self.criterion_type == "l1": self.criterion = torch.nn.L1Loss() elif self.criterion_type == "l2": self.criterion = torch.nn.L2loss() elif self.criterion_type == "fro": self.criterion = None else: raise NotImplementedError(f"{criterion} criterion has not been supported.") def forward(self, x, gt): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). gt (Tensor): Ground-truth tensor with shape (n, c, h, w). Returns: Tensor: Forward results. """ # extract vgg features x_features = self.vgg(x) gt_features = self.vgg(gt.detach()) # calculate perceptual loss if self.perceptual_weight > 0: percep_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": percep_loss += ( torch.norm(x_features[k] - gt_features[k], p="fro") * self.layer_weights[k] ) else: percep_loss += ( self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] ) percep_loss *= self.perceptual_weight else: percep_loss = None # calculate style loss if self.style_weight > 0: style_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": style_loss += ( torch.norm( self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p="fro", ) * self.layer_weights[k] ) else: style_loss += ( self.criterion( self._gram_mat(x_features[k]), self._gram_mat(gt_features[k]), ) * self.layer_weights[k] ) style_loss *= self.style_weight else: style_loss = None return percep_loss, style_loss def _gram_mat(self, x): """Calculate Gram matrix. 
Args: x (torch.Tensor): Tensor with shape of (n, c, h, w). Returns: torch.Tensor: Gram matrix. """ n, c, h, w = x.size() features = x.view(n, c, w * h) features_t = features.transpose(1, 2) gram = features.bmm(features_t) / (c * h * w) return gram @LOSS_REGISTRY.register() class CharbonnierLoss(nn.Module): """Charbonnier Loss (L1)""" def __init__(self, eps=1e-6): super(CharbonnierLoss, self).__init__() self.eps = eps def forward(self, x, y): diff = x - y loss = torch.mean(torch.sqrt(diff * diff + self.eps)) return loss class GradientPenaltyLoss(nn.Module): def __init__(self, device=torch.device("cpu")): super(GradientPenaltyLoss, self).__init__() self.register_buffer("grad_outputs", torch.Tensor()) self.grad_outputs = self.grad_outputs.to(device) def get_grad_outputs(self, input): if self.grad_outputs.size() != input.size(): self.grad_outputs.resize_(input.size()).fill_(1.0) return self.grad_outputs def forward(self, interp, interp_crit): grad_outputs = self.get_grad_outputs(interp_crit) grad_interp = torch.autograd.grad( outputs=interp_crit, inputs=interp, grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True, )[0] grad_interp = grad_interp.view(grad_interp.size(0), -1) grad_interp_norm = grad_interp.norm(2, dim=1) loss = ((grad_interp_norm - 1) ** 2).mean() return loss ================================================ FILE: codes/config/Maeda/archs/lr_scheduler.py ================================================ import math from collections import Counter, defaultdict import torch from torch.optim.lr_scheduler import _LRScheduler from utils.registry import LR_SCHEDULER_REGISTRY @LR_SCHEDULER_REGISTRY.register() class LinearDecayLR(_LRScheduler): def __init__( self, optimizer, decay_prop, total_steps, last_epoch=-1, ): self.decay_prop = decay_prop self.total_steps = total_steps super().__init__(optimizer, last_epoch) def get_lr(self): return [ group["initial_lr"] * (1 - (self.last_epoch + 1) * self.decay_prop / self.total_steps) for 
group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class MultiStepRestartLR(_LRScheduler): def __init__( self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, clear_state=False, last_epoch=-1, ): self.milestones = Counter(milestones) self.gamma = gamma self.clear_state = clear_state self.restarts = restarts if restarts else [0] self.restart_weights = weights if weights else [1] assert len(self.restarts) == len( self.restart_weights ), "restarts and their weights do not match." super().__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch in self.restarts: if self.clear_state: self.optimizer.state = defaultdict(dict) weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [ group["initial_lr"] * weight for group in self.optimizer.param_groups ] if self.last_epoch not in self.milestones: return [group["lr"] for group in self.optimizer.param_groups] return [ group["lr"] * self.gamma ** self.milestones[self.last_epoch] for group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class CosineAnnealingRestartLR(_LRScheduler): def __init__( self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1 ): self.T_period = T_period self.T_max = self.T_period[0] # current T period self.eta_min = eta_min self.restarts = restarts if restarts else [0] self.restart_weights = weights if weights else [1] self.last_restart = 0 assert len(self.restarts) == len( self.restart_weights ), "restarts and their weights do not match." 
super().__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch == 0: return self.base_lrs elif self.last_epoch in self.restarts: self.last_restart = self.last_epoch self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [ group["initial_lr"] * weight for group in self.optimizer.param_groups ] elif (self.last_epoch - self.last_restart - 1 - self.T_max) % ( 2 * self.T_max ) == 0: return [ group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) ] return [ (1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / ( 1 + math.cos( math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max ) ) * (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups ] ================================================ FILE: codes/config/Maeda/archs/module_util.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init def initialize_weights(net_l, scale=1): if not isinstance(net_l, list): net_l = [net_l] for net in net_l: for m in net.modules(): if isinstance(m, nn.Conv2d): init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale # for residual block if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.Linear): init.kaiming_normal_(m.weight, a=0, mode="fan_in") m.weight.data *= scale if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): init.constant_(m.weight, 1) init.constant_(m.bias.data, 0.0) def make_layer(block, n_layers): layers = [] for _ in range(n_layers): layers.append(block()) return nn.Sequential(*layers) class ResidualBlock_noBN(nn.Module): """Residual block w/o BN ---Conv-ReLU-Conv-+- |________________| """ def __init__(self, nf=64): super(ResidualBlock_noBN, self).__init__() 
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) # initialization initialize_weights([self.conv1, self.conv2], 0.1) def forward(self, x): identity = x out = F.relu(self.conv1(x), inplace=True) out = self.conv2(out) return identity + out def flow_warp(x, flow, interp_mode="bilinear", padding_mode="zeros"): """Warp an image or feature map with optical flow Args: x (Tensor): size (N, C, H, W) flow (Tensor): size (N, H, W, 2), normal value interp_mode (str): 'nearest' or 'bilinear' padding_mode (str): 'zeros' or 'border' or 'reflection' Returns: Tensor: warped image or feature map """ assert x.size()[-2:] == flow.size()[1:3] B, C, H, W = x.size() # mesh grid grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 grid.requires_grad = False grid = grid.type_as(x) vgrid = grid + flow # scale grid to [-1,1] vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) return output ================================================ FILE: codes/config/Maeda/archs/rcan.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from utils.registry import ARCH_REGISTRY def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias ) class MeanShift(nn.Conv2d): def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): super(MeanShift, self).__init__(3, 3, kernel_size=1) std = torch.Tensor(rgb_std) self.weight.data = torch.eye(3).view(3, 3, 1, 1) self.weight.data.div_(std.view(3, 1, 1, 1)) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 
self.bias.data.div_(std) self.requires_grad = False class BasicBlock(nn.Sequential): def __init__( self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True), ): m = [ nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), stride=stride, bias=bias, ) ] if bn: m.append(nn.BatchNorm2d(out_channels)) if act is not None: m.append(act) super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): def __init__( self, conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(ResBlock, self).__init__() m = [] for i in range(2): m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: m.append(nn.BatchNorm2d(n_feat)) if i == 0: m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale def forward(self, x): res = self.body(x).mul(self.res_scale) res += x return res class Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
for _ in range(int(math.log(scale, 2))): m.append(conv(n_feat, 4 * n_feat, 3, bias)) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 3: m.append(conv(n_feat, 9 * n_feat, 3, bias)) m.append(nn.PixelShuffle(3)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) else: raise NotImplementedError super(Upsampler, self).__init__(*m) def make_model(args, parent=False): return RCAN(args) ## Channel Attention (CA) Layer class CALayer(nn.Module): def __init__(self, channel, reduction=16): super(CALayer, self).__init__() # global average pooling: feature --> point self.avg_pool = nn.AdaptiveAvgPool2d(1) # feature channel downscale and upscale --> channel weight self.conv_du = nn.Sequential( nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), nn.ReLU(inplace=True), nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), nn.Sigmoid(), ) def forward(self, x): y = self.avg_pool(x) y = self.conv_du(y) return x * y ## Residual Channel Attention Block (RCAB) class RCAB(nn.Module): def __init__( self, conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(RCAB, self).__init__() modules_body = [] for i in range(2): modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: modules_body.append(nn.BatchNorm2d(n_feat)) if i == 0: modules_body.append(act) modules_body.append(CALayer(n_feat, reduction)) self.body = nn.Sequential(*modules_body) self.res_scale = res_scale def forward(self, x): res = self.body(x) # res = self.body(x).mul(self.res_scale) res += x return res ## Residual Group (RG) class ResidualGroup(nn.Module): def __init__( self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks ): super(ResidualGroup, self).__init__() modules_body = [] modules_body = [ RCAB( conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ) for _ in range(n_resblocks) ] 
modules_body.append(conv(n_feat, n_feat, kernel_size)) self.body = nn.Sequential(*modules_body) def forward(self, x): res = self.body(x) res += x return res ## Residual Channel Attention Network (RCAN) @ARCH_REGISTRY.register() class RCAN(nn.Module): def __init__(self, ng, nb, nf, reduction=16, upscale=4, conv=default_conv): super(RCAN, self).__init__() n_resgroups = ng n_resblocks = nb n_feats = nf kernel_size = 3 reduction = reduction scale = upscale act = nn.ReLU(True) # RGB mean for DIV2K rgb_mean = (0.4488, 0.4371, 0.4040) rgb_std = (1.0, 1.0, 1.0) self.sub_mean = MeanShift(1.0, rgb_mean, rgb_std, -1) # define head module modules_head = [conv(3, n_feats, kernel_size)] # define body module modules_body = [ ResidualGroup( conv, n_feats, kernel_size, reduction, act=act, res_scale=1.0, n_resblocks=nb, ) for _ in range(ng) ] modules_body.append(conv(n_feats, n_feats, kernel_size)) # define tail module modules_tail = [ Upsampler(conv, scale, n_feats, act=False), conv(n_feats, 3, kernel_size), ] self.add_mean = MeanShift(1.0, rgb_mean, rgb_std, 1) self.head = nn.Sequential(*modules_head) self.body = nn.Sequential(*modules_body) self.tail = nn.Sequential(*modules_tail) def forward(self, x): x = self.sub_mean(x) x = self.head(x) res = self.body(x) res += x x = self.tail(res) x = self.add_mean(x) return x def load_state_dict(self, state_dict, strict=False): own_state = self.state_dict() for name, param in state_dict.items(): if name in own_state: if isinstance(param, nn.Parameter): param = param.data try: own_state[name].copy_(param) except Exception: if name.find("tail") >= 0: print("Replace pre-trained upsampler to new one...") else: raise RuntimeError( "While copying the parameter named {}, " "whose dimensions in the model are {} and " "whose dimensions in the checkpoint are {}.".format( name, own_state[name].size(), param.size() ) ) elif strict: if name.find("tail") == -1: raise KeyError('unexpected key "{}" in state_dict'.format(name)) if strict: missing = 
set(own_state.keys()) - set(state_dict.keys()) if len(missing) > 0: raise KeyError('missing keys in state_dict: "{}"'.format(missing)) ================================================ FILE: codes/config/Maeda/archs/rrdb.py ================================================ import functools from utils.registry import ARCH_REGISTRY from .module_util import * class ResidualDenseBlock_5C(nn.Module): def __init__(self, nf=64, gc=32, bias=True): super(ResidualDenseBlock_5C, self).__init__() # gc: growth channel, i.e. intermediate channels self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) # initialization initialize_weights( [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1 ) def forward(self, x): x1 = self.lrelu(self.conv1(x)) x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) return x5 * 0.2 + x class RRDB(nn.Module): """Residual in Residual Dense Block""" def __init__(self, nf, gc=32): super(RRDB, self).__init__() self.RDB1 = ResidualDenseBlock_5C(nf, gc) self.RDB2 = ResidualDenseBlock_5C(nf, gc) self.RDB3 = ResidualDenseBlock_5C(nf, gc) def forward(self, x): out = self.RDB1(x) out = self.RDB2(out) out = self.RDB3(out) return out * 0.2 + x @ARCH_REGISTRY.register() class RRDBNet(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4): super(RRDBNet, self).__init__() self.upscale = upscale RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.RRDB_trunk = make_layer(RRDB_block_f, nb) self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 
1, bias=True) #### upsampling self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) if upscale == 4: self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): fea = self.conv_first(x) trunk = self.trunk_conv(self.RRDB_trunk(fea)) fea = fea + trunk if self.upscale == 2 or self.upscale == 3: fea = self.lrelu( self.upconv1( F.interpolate(fea, scale_factor=self.upscale, mode="nearest") ) ) if self.upscale == 4: fea = self.lrelu( self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest")) ) fea = self.lrelu( self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest")) ) out = self.conv_last(self.lrelu(self.HRconv(fea))) return out ================================================ FILE: codes/config/Maeda/archs/srresnet.py ================================================ import functools from utils.registry import ARCH_REGISTRY from .module_util import * @ARCH_REGISTRY.register() class MSRResNet(nn.Module): """modified SRResNet""" def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4): super(MSRResNet, self).__init__() self.upscale = upscale self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) basic_block = functools.partial(ResidualBlock_noBN, nf=nf) self.recon_trunk = make_layer(basic_block, nb) # upsampling if self.upscale == 2: self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(2) elif self.upscale == 3: self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(3) elif self.upscale == 4: self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(2) self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) # activation function self.lrelu = 
nn.LeakyReLU(negative_slope=0.1, inplace=True) # initialization initialize_weights( [self.conv_first, self.upconv1, self.HRconv, self.conv_last], 0.1 ) if self.upscale == 4: initialize_weights(self.upconv2, 0.1) def forward(self, x): fea = self.lrelu(self.conv_first(x)) out = self.recon_trunk(fea) if self.upscale == 4: out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) elif self.upscale == 3 or self.upscale == 2: out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) out = self.conv_last(self.lrelu(self.HRconv(out))) base = F.interpolate( x, scale_factor=self.upscale, mode="bilinear", align_corners=False ) out += base return out ================================================ FILE: codes/config/Maeda/archs/translator.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from utils.registry import ARCH_REGISTRY def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias ) class BasicBlock(nn.Sequential): def __init__( self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True), ): m = [ nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), stride=stride, bias=bias, ) ] if bn: m.append(nn.BatchNorm2d(out_channels)) if act is not None: m.append(act) super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): def __init__( self, conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(ResBlock, self).__init__() m = [] for i in range(2): m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: m.append(nn.BatchNorm2d(n_feat)) if i == 0: m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale def forward(self, x): res = self.body(x).mul(self.res_scale) res += x return res class 
Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? for _ in range(int(math.log(scale, 2))): m.append(conv(n_feat, 4 * n_feat, 3, bias)) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 3: m.append(conv(n_feat, 9 * n_feat, 3, bias)) m.append(nn.PixelShuffle(3)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 1: m.append(nn.Identity()) else: raise NotImplementedError super(Upsampler, self).__init__(*m) @ARCH_REGISTRY.register() class Translator(nn.Module): def __init__(self, nb, nf, noise_nf=0, scale=4, zero_tail=False, conv=default_conv): super().__init__() self.scale = scale self.noise_nf = noise_nf # define head module if scale >= 1: m_head = [conv(3 + noise_nf, nf, 3)] else: s = int(1 / scale) m_head = [nn.Conv2d(3 + noise_nf, nf, kernel_size=2 * s + 1, stride=s, padding=s)] # define body module m_body = [ ResBlock(conv, nf, 3, act=nn.ReLU(True), res_scale=1) for _ in range(nb) ] m_body.append(conv(nf, nf, 3)) # define tail module m_tail = [ Upsampler(conv, scale, nf, act=False) if scale > 1 else nn.Identity(), conv(nf, 3, 3), ] self.head = nn.Sequential(*m_head) self.body = nn.Sequential(*m_body) self.tail = nn.Sequential(*m_tail) if zero_tail: nn.init.constant_(self.tail[-1].weight, 0) nn.init.constant_(self.tail[-1].bias, 0) def forward(self, x): if self.noise_nf > 0: b, c, h, w = x.shape noise = torch.randn(b, self.noise_nf, h, w).to(x.device) inp = torch.cat([x, noise], 1) else: inp = x f = self.head(inp) f = self.body(f) f = self.tail(f) if self.scale == 1: x = f + x else: x = f + F.interpolate(x, scale_factor=self.scale) return x ================================================ FILE: codes/config/Maeda/archs/vgg.py ================================================ import os from collections import OrderedDict import torch from torch import nn as nn from 
torchvision.models import vgg as vgg from utils.registry import ARCH_REGISTRY VGG_PRETRAIN_PATH = "checkpoints/pretrained_models/vgg19-dcbb9e9d.pth" NAMES = { "vgg11": [ "conv1_1", "relu1_1", "pool1", "conv2_1", "relu2_1", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "pool5", ], "vgg13": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "pool5", ], "vgg16": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "pool5", ], "vgg19": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "conv3_4", "relu3_4", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "conv4_4", "relu4_4", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "conv5_4", "relu5_4", "pool5", ], } def insert_bn(names): """Insert bn layer after each conv. Args: names (list): The list of layer names. Returns: list: The list of layer names with bn layers. """ names_bn = [] for name in names: names_bn.append(name) if "conv" in name: position = name.replace("conv", "") names_bn.append("bn" + position) return names_bn @ARCH_REGISTRY.register() class VGGFeatureExtractor(nn.Module): """VGG network for feature extraction. In this implementation, we allow users to choose whether use normalization in the input feature and the type of vgg network. 
Note that the pretrained path must fit the vgg type. Args: layer_name_list (list[str]): Forward function returns the corresponding features according to the layer_name_list. Example: {'relu1_1', 'relu2_1', 'relu3_1'}. vgg_type (str): Set the type of vgg network. Default: 'vgg19'. use_input_norm (bool): If True, normalize the input image. Importantly, the input feature must in the range [0, 1]. Default: True. range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. Default: False. requires_grad (bool): If true, the parameters of VGG network will be optimized. Default: False. remove_pooling (bool): If true, the max pooling operations in VGG net will be removed. Default: False. pooling_stride (int): The stride of max pooling operation. Default: 2. """ def __init__( self, layer_name_list, vgg_type="vgg19", use_input_norm=True, range_norm=False, requires_grad=False, remove_pooling=False, pooling_stride=2, ): super(VGGFeatureExtractor, self).__init__() self.layer_name_list = layer_name_list self.use_input_norm = use_input_norm self.range_norm = range_norm self.names = NAMES[vgg_type.replace("_bn", "")] if "bn" in vgg_type: self.names = insert_bn(self.names) # only borrow layers that will be used to avoid unused params max_idx = 0 for v in layer_name_list: idx = self.names.index(v) if idx > max_idx: max_idx = idx if os.path.exists(VGG_PRETRAIN_PATH): vgg_net = getattr(vgg, vgg_type)(pretrained=False) state_dict = torch.load( VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage ) vgg_net.load_state_dict(state_dict) else: vgg_net = getattr(vgg, vgg_type)(pretrained=True) features = vgg_net.features[: max_idx + 1] modified_net = OrderedDict() for k, v in zip(self.names, features): if "pool" in k: # if remove_pooling is true, pooling operation will be removed if remove_pooling: continue else: # in some cases, we may want to change the default stride modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) else: modified_net[k] = v self.vgg_net 
= nn.Sequential(modified_net) if not requires_grad: self.vgg_net.eval() for param in self.parameters(): param.requires_grad = False else: self.vgg_net.train() for param in self.parameters(): param.requires_grad = True if self.use_input_norm: # the mean is for image with range [0, 1] self.register_buffer( "mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) ) # the std is for image with range [0, 1] self.register_buffer( "std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) ) def forward(self, x): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). Returns: Tensor: Forward results. """ if self.range_norm: x = (x + 1) / 2 if self.use_input_norm: x = (x - self.mean) / self.std output = {} for key, layer in self.vgg_net._modules.items(): x = layer(x) if key in self.layer_name_list: output[key] = x.clone() return output ================================================ FILE: codes/config/Maeda/count_flops.py ================================================ import argparse import sys import torch from torchsummaryX import summary sys.path.append("../../") import utils.option as option from models import create_model parser = argparse.ArgumentParser() parser.add_argument( "--opt", type=str, default="options/setting1/test/test_setting1_x4.yml", help="Path to option YMAL file of Predictor.", ) args = parser.parse_args() opt = option.parse(args.opt, root_path=".", is_train=True) opt = option.dict_to_nonedict(opt) model = create_model(opt) test_tensor = torch.randn(1, 3, 270, 180).cuda() for name, net in model.networks.items(): summary(net.cuda(), x=test_tensor) print("Above are results for net {}".format(name)) input() ================================================ FILE: codes/config/Maeda/inference.py ================================================ import argparse import logging import math import os import os.path as osp import random import sys import cv2 from collections import defaultdict from glob import glob from tqdm import 
def create_model(opt, **kwarg):
    """Instantiate the model class named by ``opt["model"]``.

    The class is looked up in ``MODEL_REGISTRY`` (populated by importing every
    ``*_model.py`` module in this package) and constructed with ``opt`` plus
    any extra keyword arguments.

    Args:
        opt (dict): Full option dictionary; ``opt["model"]`` selects the class.
        **kwarg: Forwarded verbatim to the model constructor.

    Returns:
        The constructed model instance.
    """
    model_name = opt["model"]
    instance = MODEL_REGISTRY.get(model_name)(opt, **kwarg)
    logger.info("Model [{:s}] is created.".format(instance.__class__.__name__))
    return instance
{} defined_loss_names = list(loss_opt.keys()) assert set(defined_loss_names).issubset(set(self.loss_names)) for name in defined_loss_names: loss_conf = loss_opt.get(name) if loss_conf["weight"] > 0: self.loss_weights[name] = loss_conf.pop("weight") losses[name] = build_loss(loss_conf).to(self.device) return losses def build_optimizers(self, optim_opts): optimizers = {} if "default" in optim_opts.keys(): default_optim = optim_opts.pop("default") defined_optimizer_names = list(optim_opts.keys()) assert set(defined_optimizer_names).issubset(self.networks.keys()) for name in defined_optimizer_names: optim_opt = optim_opts[name] if optim_opt is None: optim_opt = default_optim.copy() params = [] for v in self.networks[name].parameters(): if v.requires_grad: params.append(v) optim_type = optim_opt.pop("type") optimizer = getattr(torch.optim, optim_type)(params=params, **optim_opt) optimizers[name] = optimizer return optimizers def build_schedulers(self, scheduler_opts): """Set up scheduler.""" schedulers = {} if "default" in scheduler_opts.keys(): default_opt = scheduler_opts.pop("default") for name in self.optimizers.keys(): scheduler_opt = scheduler_opts[name] if scheduler_opt is None: scheduler_opt = default_opt.copy() schedulers[name] = build_scheduler(self.optimizers[name], scheduler_opt) return schedulers def model_to_device(self, net): """Model to device. It also warps models with DistributedDataParallel or DataParallel. 
Args: net (nn.Module) """ net = net.to(self.device) if self.opt["dist"]: net = DistributedDataParallel(net, device_ids=[torch.cuda.current_device()]) else: net = DataParallel(net) return net def print_network(self, net): # Generator s, n = self.get_network_description(net) if isinstance(net, nn.DataParallel) or isinstance(net, DistributedDataParallel): net_struc_str = "{} - {}".format( net.__class__.__name__, net.module.__class__.__name__ ) else: net_struc_str = "{}".format(net.__class__.__name__) if self.rank <= 0: logger.info( "Network G structure: {}, with parameters: {:,d}".format( net_struc_str, n ) ) logger.info(s) def set_optimizer(self, names, operation): for name in names: getattr(self.optimizers[name], operation)() def set_requires_grad(self, names, requires_grad): for name in names: if isinstance(self.networks[name], nn.Module): for v in self.networks[name].parameters(): v.requires_grad = requires_grad def set_network_state(self, names, state): for name in names: if isinstance(self.networks[name], nn.Module): getattr(self.networks[name], state)() def clip_grad_norm(self, names, norm): for name in names: nn.utils.clip_grad_norm_(self.networks[name].parameters(), max_norm=norm) def _set_lr(self, lr_groups_l): """set learning rate for warmup, lr_groups_l: list for lr_groups. 
each for a optimizer""" for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): for param_group, lr in zip(optimizer.param_groups, lr_groups): param_group["lr"] = lr def _get_init_lr(self): # get the initial lr, which is set by the scheduler init_lr_groups_l = [] for optimizer in self.optimizers: init_lr_groups_l.append([v["initial_lr"] for v in optimizer.param_groups]) return init_lr_groups_l def update_learning_rate(self, cur_iter, warmup_iter=-1): for _, scheduler in self.schedulers.items(): scheduler.step() #### set up warm up learning rate if cur_iter < warmup_iter: # get initial lr for each group init_lr_g_l = self._get_init_lr() # modify warming-up learning rates warm_up_lr_l = [] for init_lr_g in init_lr_g_l: warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) # set learning rate self._set_lr(warm_up_lr_l) def get_current_learning_rate(self): # return self.schedulers[0].get_lr()[0] return list(self.optimizers.values())[0].param_groups[0]["lr"] def get_network_description(self, network): """Get the string and total parameters of the network""" if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module s = str(network) n = sum(map(lambda x: x.numel(), network.parameters())) return s, n def save_network(self, network, network_label, iter_label): save_filename = "{}_{}.pth".format(iter_label, network_label) save_path = os.path.join(self.opt["path"]["models"], save_filename) if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module state_dict = network.state_dict() for key, param in state_dict.items(): state_dict[key] = param.cpu() torch.save(state_dict, save_path) def save(self, iter_label): for name in self.optimizers.keys(): self.save_network(self.networks[name], name, iter_label) def load_network(self, network, load_path, strict=True): if load_path is not None: if isinstance(network, nn.DataParallel) or isinstance( 
def save_training_state(self, epoch, iter_step):
    """Save epoch/iter plus all optimizer and scheduler states for resuming.

    Args:
        epoch (int): Current epoch.
        iter_step (int): Current iteration; also names the ``.state`` file.
    """
    state = {"epoch": epoch, "iter": iter_step, "schedulers": {}, "optimizers": {}}
    for name, sched in self.schedulers.items():
        state["schedulers"][name] = sched.state_dict()
    for name, optim in self.optimizers.items():
        state["optimizers"][name] = optim.state_dict()
    save_filename = "{}.state".format(iter_step)
    save_path = os.path.join(self.opt["path"]["training_state"], save_filename)
    torch.save(state, save_path)

def resume_training(self, resume_state):
    """Restore optimizer and scheduler states saved by save_training_state.

    Args:
        resume_state (dict): State dict with ``optimizers`` and ``schedulers``.

    Raises:
        AssertionError: If the saved and current optimizer/scheduler sets
            differ in size.
    """
    resume_optimizers = resume_state["optimizers"]
    resume_schedulers = resume_state["schedulers"]
    assert len(resume_optimizers) == len(
        self.optimizers
    ), "Wrong lengths of optimizers"
    assert len(resume_schedulers) == len(
        self.schedulers
    ), "Wrong lengths of schedulers"
    for name, opt_state in resume_optimizers.items():
        self.optimizers[name].load_state_dict(opt_state)
    for name, sched_state in resume_schedulers.items():
        self.schedulers[name].load_state_dict(sched_state)

def reduce_loss_dict(self, loss_dict):
    """Reduce a loss dict to plain floats, averaging across GPUs when dist.

    In distributed training the losses are reduced to rank 0 and averaged
    over the world size; otherwise values pass through unchanged.

    Args:
        loss_dict (OrderedDict): Mapping of loss name -> scalar tensor.

    Returns:
        OrderedDict: Mapping of loss name -> Python float.
    """
    with torch.no_grad():
        if self.opt["dist"]:
            keys, values = [], []
            for key, value in loss_dict.items():
                keys.append(key)
                values.append(value)
            stacked = torch.stack(values, 0)
            torch.distributed.reduce(stacked, dst=0)
            if self.rank == 0:
                # Only rank 0 holds the summed losses; turn them into means.
                stacked /= self.world_size
            loss_dict = {key: loss for key, loss in zip(keys, stacked)}

        log_dict = OrderedDict()
        for name, value in loss_dict.items():
            log_dict[name] = value.mean().item()
        return log_dict

def get_current_log(self):
    """Return the loss log dict recorded by the latest optimization step."""
    return self.log_dict
data["src"].to(self.device) def forward(self): self.fake_syn_lr = self.netG1(self.real_lr) self.rec_real_lr = self.netG2(self.fake_syn_lr) self.fake_real_lr = self.netG2(self.syn_lr) self.rec_syn_lr = self.netG1(self.fake_real_lr) self.fake_real_hr = self.netSR(self.fake_syn_lr) self.fake_syn_hr = self.netSR(self.rec_syn_lr) def optimize_parameters(self, step): loss_dict = OrderedDict() self.forward() loss_G = 0 self.set_requires_grad(["netD1", "netD2", "netD3"], False) g1_adv_loss = self.calculate_gan_loss_G( self.netD1, self.losses["g1d1_adv"], self.syn_lr, self.fake_syn_lr ) loss_dict["g1_adv"] = g1_adv_loss.item() loss_G += self.loss_weights["g1d1_adv"] * g1_adv_loss g2_adv_loss = self.calculate_gan_loss_G( self.netD2, self.losses["g2d2_adv"], self.real_lr, self.fake_real_lr ) loss_dict["g2_adv"] = g2_adv_loss.item() loss_G += self.loss_weights["g2d2_adv"] * g2_adv_loss g1g2_cycle = self.losses["g1g2_cycle"](self.rec_real_lr, self.real_lr) loss_dict["g1g2_cycle"] = g1g2_cycle.item() loss_G += self.loss_weights["g1g2_cycle"] * g1g2_cycle g2g1_cycle = self.losses["g2g1_cycle"](self.rec_syn_lr, self.syn_lr) loss_dict["g2g1_cycle"] = g2g1_cycle.item() loss_G += self.loss_weights["g2g1_cycle"] * g2g1_cycle if self.losses.get("g1_idt"): self.idt_syn_lr = self.netG1(self.syn_lr) g1_idt = self.losses["g1_idt"](self.idt_syn_lr, self.syn_lr) loss_dict["g1_idt"] = g1_idt.item() loss_G += self.loss_weights["g1_idt"] * g1_idt if self.losses.get("g2_idt"): self.idt_real_lr = self.netG2(self.real_lr) g2_idt = self.losses["g2_idt"](self.idt_real_lr, self.real_lr) loss_dict["g2_idt"] = g2_idt.item() loss_G += self.loss_weights["g2_idt"] * g2_idt sr_pix = self.losses["sr_pix"](self.fake_syn_hr, self.syn_hr) loss_dict["sr_pix"] = sr_pix.item() loss_G += self.loss_weights["sr_pix"] * sr_pix sr_adv = self.calculate_gan_loss_G( self.netD3, self.losses["srd3_adv"], self.syn_hr, self.fake_real_hr ) loss_dict["sr_adv"] = sr_adv.item() loss_G += self.loss_weights["srd3_adv"] * sr_adv 
self.set_optimizer( names=["netG1", "netG2", "netSR"], operation="zero_grad" ) loss_G.backward() self.set_optimizer(names=["netG1", "netG2", "netSR"], operation="step") ## update D1, D2, D3 self.set_requires_grad(["netD1", "netD2", "netD3"], True) loss_D = 0 loss_d1 = self.calculate_gan_loss_D( self.netD1, self.losses["g1d1_adv"], self.syn_lr, self.fake_syn_lr ) loss_dict["d1_adv"] = loss_d1.item() loss_D += self.loss_weights["g1d1_adv"] * loss_d1 loss_d2 = self.calculate_gan_loss_D( self.netD2, self.losses["g2d2_adv"], self.real_lr, self.fake_real_lr ) loss_dict["d2_adv"] = loss_d2.item() loss_D += self.loss_weights["g2d2_adv"] * loss_d2 loss_d3 = self.calculate_gan_loss_D( self.netD3, self.losses["srd3_adv"], self.syn_hr, self.fake_real_hr ) loss_dict["d3_adv"] = loss_d3.item() loss_D += self.loss_weights["srd3_adv"] * loss_d3 self.set_optimizer( names=["netD1", "netD2", "netD3"], operation="zero_grad" ) loss_D.backward() self.set_optimizer(names=["netD1", "netD2", "netD3"], operation="step") self.log_dict = loss_dict def calculate_gan_loss_D(self, netD, criterion, real, fake): d_pred_fake = netD(fake.detach()) d_pred_real = netD(real) loss_real = criterion(d_pred_real, True, is_disc=True) loss_fake = criterion(d_pred_fake, False, is_disc=True) return (loss_real + loss_fake) / 2 def calculate_gan_loss_G(self, netD, criterion, real, fake): d_pred_fake = netD(fake) loss_real = criterion(d_pred_fake, True, is_disc=False) return loss_real def test(self, data): self.real_lr = data["src"].to(self.device) self.set_network_state(["netSR", "netG1"], "eval") with torch.no_grad(): self.fake_syn_lr = self.netG1(self.real_lr) self.fake_real_hr = self.netSR(self.fake_syn_lr) self.set_network_state(["netSR", "netG1"], "train") def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict["lr"] = self.real_lr.detach()[0].float().cpu() out_dict["sr"] = self.fake_real_hr.detach()[0].float().cpu() return out_dict ================================================ 
FILE: codes/config/Maeda/options/test/2017Track2.yml ================================================ #### general settings name: 2017Track2 use_tb_logger: false model: PseudoSupModel scale: 4 gpu_ids: [0] metrics: [psnr, ssim, lpips] datasets: test1: name: 2017Track2 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test2: # name: 2018Track2 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_mild.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test3: # name: 2018Track3 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_difficult.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test4: # name: 2018Track4 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_wild.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test5: # name: 2020Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/track1_valid_input.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/2017Track2/models/latest_netSR.pth strict_load: true netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: log/2017Track2/models/latest_netG1.pth strict_load: true ================================================ FILE: codes/config/Maeda/options/test/2018Track2.yml ================================================ #### general settings name: 2018Track2 use_tb_logger: false model: PseudoSupModel scale: 4 gpu_ids: [1] metrics: [best_psnr, best_ssim, lpips] datasets: # test1: # name: 2017Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb # dataroot_tgt: 
/home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb test2: name: 2018Track2 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test3: # name: 2018Track3 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_difficult.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test4: # name: 2018Track4 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_wild.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test5: # name: 2020Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/track1_valid_input.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/2018Track2/models/latest_netSR.pth strict_load: true netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: log/2018Track2/models/latest_netG1.pth strict_load: true ================================================ FILE: codes/config/Maeda/options/test/2018Track4.yml ================================================ #### general settings name: 2018Track4 use_tb_logger: false model: PseudoSupModel scale: 4 gpu_ids: [2] metrics: [best_psnr, best_ssim, lpips] datasets: # test1: # name: 2017Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test2: # name: 2018Track2 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test3: # name: 2018Track3 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_difficult.lmdb # dataroot_tgt: 
/home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb test4: name: 2018Track4 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test5: # name: 2020Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/track1_valid_input.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/2018Track4/models/latest_netSR.pth strict_load: true netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: log/2018Track4/models/latest_netG1.pth strict_load: true ================================================ FILE: codes/config/Maeda/options/test/2020Track1.yml ================================================ #### general settings name: 2020Track1 use_tb_logger: false model: PseudoSupModel scale: 4 gpu_ids: [4] metrics: [psnr, ssim, lpips] datasets: # test1: # name: 2017Track1 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test2: # name: 2018Track2 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb # test3: # name: 2018Track3 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/valid_difficult.lmdb # dataroot_tgt: /home/lzx/SRDatasets/NTIRE2018/valid_HR.lmdb # test4: # name: 2018Track4 # mode: PairedDataset # data_type: lmdb # dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid.lmdb # dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb test5: name: 2020Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid.lmdb dataroot_tgt: 
/home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/2020Track1/models/285000_netSR.pth strict_load: true netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: log/2020Track1/models/285000_netG1.pth strict_load: true ================================================ FILE: codes/config/Maeda/options/train/2017Track2.yml ================================================ #### general settings name: 2017Track2 use_tb_logger: false model: PseudoSupModel scale: 4 gpu_ids: [1] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: PairedRefDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_ref_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_ref_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/train_LR/x4_half.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2017Track2_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true netD3: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: ~ strict_load: true #### network structures netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: ~ strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 16 noise_nf: 1 zero_tail: true scale: 1 pretrain: path: ~ strict_load: 
true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ max_grad_norm: 50 losses: sr_pix: type: L1Loss weight: 1 srd3_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.1 g1d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g2d2_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g1g2_cycle: type: L1Loss weight: 1.0 g2g1_cycle: type: L1Loss weight: 1.0 g1_idt: type: L1Loss weight: 1 g2_idt: type: L1Loss weight: 1 optimizers: default: type: Adam lr: !!float 1e-4 betas: [0.9, 0.999] netSR: ~ netG1: ~ netG2: ~ netD1: ~ netD2: ~ netD3: ~ niter: 300000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [100000, 180000, 240000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/Maeda/options/train/2018Track2.yml ================================================ #### general settings name: 2018Track2 use_tb_logger: false model: PseudoSupModel scale: 4 gpu_ids: [2] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: PairedRefDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_ref_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_ref_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/x4_half.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2017Track1_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### 
network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: # path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt path: log/2018Track2/models/210000_netSR.pth strict_load: true netD3: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/2018Track2/models/210000_netD3.pth strict_load: true #### network structures netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: log/2018Track2/models/210000_netG1.pth strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/2018Track2/models/210000_netD1.pth strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 16 noise_nf: 1 zero_tail: true scale: 1 pretrain: path: log/2018Track2/models/210000_netG2.pth strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/2018Track2/models/210000_netD2.pth strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ max_grad_norm: 50 losses: sr_pix: type: L1Loss weight: 1 srd3_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.1 g1d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g2d2_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g1g2_cycle: type: L1Loss weight: 1.0 g2g1_cycle: type: L1Loss weight: 1.0 g1_idt: type: L1Loss weight: 1 g2_idt: type: L1Loss weight: 1 optimizers: default: type: Adam lr: !!float 1e-4 betas: [0.9, 0.999] netSR: ~ netG1: ~ netG2: ~ netD1: ~ netD2: ~ netD3: ~ niter: 300000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [100000, 180000, 240000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 
================================================ FILE: codes/config/Maeda/options/train/2018Track4.yml ================================================ #### general settings name: 2018Track4 use_tb_logger: false model: PseudoSupModel scale: 4 gpu_ids: [4] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: PairedRefDataset data_type: lmdb color: RGB ratios: [200, 50] dataroot_ref_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_ref_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/x4.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2018Track4_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: # path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt path: log/2018Track4/models/210000_netSR.pth strict_load: true netD3: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/2018Track4/models/210000_netD3.pth strict_load: true #### network structures netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: log/2018Track4/models/210000_netG1.pth strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/2018Track4/models/210000_netD1.pth strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 16 noise_nf: 1 zero_tail: true scale: 1 pretrain: path: log/2018Track4/models/210000_netG2.pth strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/2018Track4/models/210000_netD2.pth strict_load: true #### training settings: 
learning rate scheme, loss train: resume_state: ~ max_grad_norm: 50 losses: sr_pix: type: L1Loss weight: 1 srd3_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.1 g1d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g2d2_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g1g2_cycle: type: L1Loss weight: 1.0 g2g1_cycle: type: L1Loss weight: 1.0 g1_idt: type: L1Loss weight: 1 g2_idt: type: L1Loss weight: 1 optimizers: default: type: Adam lr: !!float 1e-4 betas: [0.9, 0.999] netSR: ~ netG1: ~ netG2: ~ netD1: ~ netD2: ~ netD3: ~ niter: 300000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [100000, 180000, 240000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/Maeda/options/train/2020Track1.yml ================================================ #### general settings name: 2020Track1 use_tb_logger: false model: PseudoSupModel scale: 4 gpu_ids: [3] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: PairedRefDataset data_type: lmdb color: RGB ratios: [200, 50] dataroot_ref_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_ref_src: /home/lzx/SRDatasets/DIV2K_train/BicLR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/train_source.lmdb use_shuffle: true workers_per_gpu: 8 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2020Track1_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: # path: 
../../../checkpoints/EDSR/edsr_baseline_x4-new.pt path: log/2020Track1/models/180000_netSR.pth strict_load: true netD3: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/2020Track1/models/180000_netD3.pth strict_load: true #### network structures netG1: which_network: Translator setting: nf: 64 nb: 8 zero_tail: true scale: 1 pretrain: path: log/2020Track1/models/180000_netG1.pth strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/2020Track1/models/180000_netD1.pth strict_load: true netG2: which_network: Translator setting: nf: 64 nb: 16 noise_nf: 1 zero_tail: true scale: 1 pretrain: path: log/2020Track1/models/180000_netG2.pth strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: log/2020Track1/models/180000_netD2.pth strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ max_grad_norm: 50 losses: sr_pix: type: L1Loss weight: 1 srd3_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.1 g1d1_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g2d2_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 g1g2_cycle: type: L1Loss weight: 1.0 g2g1_cycle: type: L1Loss weight: 1.0 g1_idt: type: L1Loss weight: 1 g2_idt: type: L1Loss weight: 1 optimizers: default: type: Adam lr: !!float 1e-4 betas: [0.9, 0.999] netSR: ~ netG1: ~ netG2: ~ netD1: ~ netD2: ~ netD3: ~ niter: 300000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [100000, 180000, 240000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/Maeda/test.py ================================================ import 
def parse_args():
    """Parse command-line options for testing; returns the argparse namespace."""
    cli = argparse.ArgumentParser(description="Train keypoints network")
    # general
    cli.add_argument(
        "--opt", type=str, required=True, help="experiment configure file name"
    )
    cli.add_argument(
        "--root_path",
        type=str,
        default="../../../",
        help="experiment configure file name",
    )
    # distributed training
    cli.add_argument("--gpu", type=str, help="gpu id for multiprocessing training")
    cli.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="number of nodes for distributed training",
    )
    cli.add_argument(
        "--dist-url",
        type=str,
        default="tcp://127.0.0.1:23456",
        help="url used to set up distributed training",
    )
    cli.add_argument(
        "--rank", type=int, default=0, help="node rank for distributed training"
    )
    return cli.parse_args()
def validate(
    model, dataset, dist_loader, opt, measure, dataset_dir, test_set_name, logger
):
    """Run inference over a test set, save SR outputs, and log IQA scores.

    Each rank processes an interleaved subset of the dataset; per-image
    scores are stored into fixed-size CUDA tensors and reduced to rank 0
    for averaging. Scores are computed on the BGR output and, for RGB
    images, on the Y channel as well.
    """
    test_results = {}
    test_results_y = {}
    for metric in opt["metrics"]:
        test_results[metric] = torch.zeros((len(dataset))).cuda()
        test_results_y[metric] = torch.zeros((len(dataset))).cuda()

    if opt["dist"]:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        world_size = 1
        rank = 0

    # images handled by this rank: rank, rank + world_size, ...
    indices = list(range(rank, len(dataset), world_size))

    for idx, test_data in enumerate(dist_loader):
        idx = indices[idx]  # global dataset index of this rank-local sample
        img_path = test_data["src_path"][0]
        img_name = img_path.split("/")[-1].split(".")[0]

        model.test(test_data)
        visuals = model.get_current_visuals()
        sr_img = util.tensor2img(visuals["sr"])  # uint8

        suffix = opt["suffix"]
        if suffix:
            save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png")
        else:
            save_img_path = os.path.join(dataset_dir, img_name + ".png")
        util.save_img(sr_img, save_img_path)

        message = "img:{:15s}; ".format(img_name)

        crop_border = opt["crop_border"] if opt["crop_border"] else opt["scale"]
        if crop_border == 0:
            cropped_sr_img = sr_img
        else:
            cropped_sr_img = sr_img[
                crop_border:-crop_border, crop_border:-crop_border, :
            ]

        # fix: gt_img must always be bound, since it is read again for the
        # Y-channel scores below (previously a NameError without "tgt")
        gt_img = None
        cropped_gt_img = None
        if "tgt" in test_data.keys():
            gt_img = util.tensor2img(test_data["tgt"][0].double().cpu())
            if crop_border == 0:
                cropped_gt_img = gt_img
            else:
                cropped_gt_img = gt_img[
                    crop_border:-crop_border, crop_border:-crop_border, :
                ]

        message += "Scores - "
        scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"])
        for k, v in scores.items():
            test_results[k][idx] = v
            message += "{}: {:.6f}; ".format(k, v)

        if sr_img.shape[2] == 3:  # RGB image: also score on the Y channel
            sr_img_y = bgr2ycbcr(sr_img, only_y=True)
            if crop_border == 0:
                cropped_sr_img_y = sr_img_y * 255
            else:
                cropped_sr_img_y = (
                    sr_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255
                )

            # fix: cropped_gt_img_y was unbound when no ground truth exists
            cropped_gt_img_y = None
            if gt_img is not None:
                gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                if crop_border == 0:
                    cropped_gt_img_y = gt_img_y * 255
                else:
                    cropped_gt_img_y = (
                        gt_img_y[crop_border:-crop_border, crop_border:-crop_border]
                        * 255
                    )

            message += "Y Scores - "
            scores = measure(
                res=cropped_sr_img_y, ref=cropped_gt_img_y, metrics=opt["metrics"]
            )
            for k, v in scores.items():
                test_results_y[k][idx] = v
                message += "{}: {:.6f}; ".format(k, v)

        logger.info(message)

    if opt["dist"]:
        for k, v in test_results.items():
            dist.reduce(v, dst=0)
            dist.barrier()
        for k, v in test_results_y.items():
            dist.reduce(v, dst=0)
            dist.barrier()

    # log averages (other ranks' scores were reduced into rank 0 above)
    avg_results = {}
    message = "Average Results for {}\n".format(test_set_name)
    if rank == 0:
        for k, v in test_results.items():
            avg_results[k] = sum(v) / len(v)
            message += "{}: {:.6f}; ".format(k, avg_results[k])
        logger.info(message)

    avg_results_y = {}
    message = "Average Results on Y channel for {}\n".format(test_set_name)
    if rank == 0:
        for k, v in test_results_y.items():
            # fix: Y-channel averages previously overwrote avg_results
            avg_results_y[k] = sum(v) / len(v)
            message += "{}: {:.6f}; ".format(k, avg_results_y[k])
        logger.info(message)
def main_worker(gpu, ngpus_per_node, opt, args):
    """Per-process training entry point (one process per GPU when distributed).

    Sets up the process group, loggers, dataloaders and model, then runs
    the train / validate / checkpoint loop until ``opt["train"]["niter"]``
    iterations have been performed.
    """
    if opt["dist"]:
        if args.dist_url == "env://" and args.rank == -1:
            rank = int(os.environ["RANK"])
        rank = args.rank * ngpus_per_node + gpu
        print(
            f"Init process group: dist_url: \
            {args.dist_url}, world_size: {args.world_size}, rank: {rank}"
        )
        dist.init_process_group(
            backend="nccl",
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=rank,
        )
        torch.cuda.set_device(gpu)
    else:
        rank = 0

    seed = opt["train"]["manual_seed"]
    if seed is None:
        util.set_random_seed(rank)
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # setup tensorboard and val logger
    if rank == 0:
        if opt["use_tb_logger"] and "debug" not in opt["name"]:
            tb_logger = SummaryWriter(log_dir="log/{}/tb_logger/".format(opt["name"]))
        util.setup_logger(
            "val",
            opt["path"]["log"],
            "val_" + opt["name"],
            level=logging.INFO,
            screen=True,
            tofile=True,
        )
    # fix: validate() runs on every rank, so the IQA measure must exist on
    # every rank, not only rank 0 (matches the sibling test.py main_worker)
    measure = IQA(metrics=opt["metrics"], cuda=True)

    # config loggers. Before it, the log will not work
    util.setup_logger(
        "base",
        opt["path"]["log"],
        "train_" + opt["name"] + "_rank{}".format(rank),
        level=logging.INFO if rank == 0 else logging.ERROR,
        screen=True,
        tofile=True,
    )
    logger = logging.getLogger("base")
    if rank == 0:
        logger.info(option.dict2str(opt))

    # create dataset
    (
        train_set,
        train_loader,
        val_set,
        val_loader,
        total_iters,
        total_epochs,
    ) = setup_dataloaer(opt, logger)

    # create model
    model = create_model(opt)

    # loading resume state if exists
    if opt["train"].get("resume_state", None):
        # distributed resuming: all load into default GPU
        device_id = gpu
        resume_state = torch.load(
            opt["train"]["resume_state"],
            map_location=lambda storage, loc: storage.cuda(device_id),
        )
        logger.info(
            "Resuming training from epoch: {}, iter: {}.".format(
                resume_state["epoch"], resume_state["iter"]
            )
        )
        start_epoch = resume_state["epoch"]
        current_step = resume_state["iter"]
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    logger.info(
        "Start training from epoch: {:d}, iter: {:d}".format(start_epoch, current_step)
    )

    data_time, iter_time = time.time(), time.time()
    avg_data_time = avg_iter_time = 0
    count = 0

    for epoch in range(start_epoch, total_epochs + 1):
        for _, train_data in enumerate(train_loader):
            current_step += 1
            count += 1
            if current_step > total_iters:
                break

            data_time = time.time() - data_time
            avg_data_time = (avg_data_time * (count - 1) + data_time) / count

            model.feed_data(train_data)
            model.optimize_parameters(current_step)
            model.update_learning_rate(
                current_step, warmup_iter=opt["train"]["warmup_iter"]
            )

            iter_time = time.time() - iter_time
            avg_iter_time = (avg_iter_time * (count - 1) + iter_time) / count

            # log
            if current_step % opt["logger"]["print_freq"] == 0:
                logs = model.get_current_log()
                # fix: the log prefix was an empty f-string; show progress
                message = f"<epoch:{epoch:3d}, iter:{current_step:8,d}> "
                message += f"[time (data): {avg_iter_time:.3f} ({avg_data_time:.3f})] "
                for k, v in logs.items():
                    message += "{:s}: {:.4e}; ".format(k, v)
                    # tensorboard logger
                    if opt["use_tb_logger"] and "debug" not in opt["name"]:
                        if rank == 0:
                            tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # validation
            if current_step % opt["train"]["val_freq"] == 0:
                avg_results = validate(
                    model, val_set, val_loader, opt, measure, epoch, current_step
                )
                # tensorboard logger
                if rank == 0:
                    if opt["use_tb_logger"] and "debug" not in opt["name"]:
                        for k, v in avg_results.items():
                            tb_logger.add_scalar(k, v, current_step)

            # save models and training states
            if current_step % opt["logger"]["save_checkpoint_freq"] == 0:
                if rank == 0:
                    logger.info("Saving models and training states.")
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

            data_time = time.time()
            iter_time = time.time()

        # fix: the inner ``break`` only left the current epoch; stop the
        # whole training loop once the iteration budget is spent
        if current_step > total_iters:
            break

    if rank == 0:
        logger.info("Saving the final model.")
        model.save("latest")
        logger.info("End of training.")
        if opt["use_tb_logger"] and "debug" not in opt["name"]:
            tb_logger.close()
= dist.get_world_size() else: world_size = 1 rank = 0 if rank == 0: pbar = tqdm(total=len(dataset), leave=False, dynamic_ncols=True) indices = list(range(rank, len(dataset), world_size)) for ( idx, val_data, ) in enumerate(dist_loader): idx = indices[idx] LR_img = val_data["src"] lr_img = util.tensor2img(LR_img) # save LR image for reference model.test(val_data) visuals = model.get_current_visuals() # Save images for reference img_name = val_data["src_path"][0].split("/")[-1].split(".")[0] img_dir = os.path.join(opt["path"]["val_images"], img_name) util.mkdir(img_dir) save_lr_path = os.path.join(img_dir, "{:s}_LR.png".format(img_name)) util.save_img(lr_img, save_lr_path) sr_img = util.tensor2img(visuals["sr"]) # uint8 save_img_path = os.path.join( img_dir, "{:s}_{:d}.png".format(img_name, current_step) ) util.save_img(sr_img, save_img_path) if "fake_lr" in visuals.keys(): fake_lr_img = util.tensor2img(visuals["fake_lr"]) save_img_path = os.path.join( img_dir, f"fake_lr_{current_step:d}.png" ) util.save_img(fake_lr_img, save_img_path) # calculate scores crop_size = opt["scale"] cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] if "tgt" in val_data.keys(): gt_img = util.tensor2img(val_data["tgt"]) cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] else: cropped_gt_img = gt_img = None scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v if rank == 0: for _ in range(world_size): pbar.update(1) if rank == 0: pbar.close() # log avg_results = {} message = " 0: if self.opt["spatial"]: zk = torch.randn(B, self.opt["nc"], H, W).to(x.device) else: zk = torch.randn(B, self.opt["nc"], 1, 1).to(x.device) if self.opt["mix"]: zk = zk.repeat(1, 1, H, W) if self.opt["mix"]: if self.opt["nc"] > 0: inp = torch.cat([x, zk], 1) else: inp = x else: inp = zk ksize = self.opt["ksize"] kernel = self.deg_kernel(inp).view(B, 1, ksize**2, *inp.shape[2:]) x = x.view(B*C, 1, 
class NoiseModel(nn.Module):
    """Generates an additive noise map for the learned degradation model.

    Depending on ``opt``, the generator input is random noise (``nc``
    channels), the image itself, or both concatenated (``mix``); the
    output has ``opt["dim"]`` channels.
    """

    def __init__(self, opt, scale):
        super().__init__()
        self.scale = scale
        self.opt = opt

        nc, nf, nb = opt["nc"], opt["nf"], opt["nb"]

        # spatial mode uses real spatial kernels; otherwise 1x1 convs only
        if opt["spatial"]:
            head_k = opt["head_k"]
            body_k = opt["body_k"]
        else:
            head_k = body_k = 1

        if opt["mix"]:
            in_nc = 3 + nc
        else:
            in_nc = nc
        # fix: with nc == 0 and mix disabled the network has no input at
        # all; previously this surfaced as UnboundLocalError in forward()
        if in_nc == 0:
            raise ValueError("NoiseModel requires nc > 0 when mix is disabled")

        deg_noise = [
            nn.Conv2d(in_nc, nf, head_k, 1, head_k // 2),
            nn.BatchNorm2d(nf),
            nn.ReLU(True),
            *[ResBlock(nf=nf, ksize=body_k) for _ in range(nb)],
            nn.Conv2d(nf, opt["dim"], 1, 1, 0),
        ]
        self.deg_noise = nn.Sequential(*deg_noise)

        # start the last conv at (or near) zero so initial noise is negligible
        if opt["zero_init"]:
            nn.init.constant_(self.deg_noise[-1].weight, 0)
            nn.init.constant_(self.deg_noise[-1].bias, 0)
        else:
            nn.init.normal_(self.deg_noise[-1].weight, 0.001)
            nn.init.constant_(self.deg_noise[-1].bias, 0)

    def forward(self, x):
        B, C, H, W = x.shape

        zn = None
        if self.opt["nc"] > 0:
            if self.opt["spatial"]:
                zn = torch.randn(B, self.opt["nc"], H, W).to(x.device)
            else:
                zn = torch.randn(B, self.opt["nc"], 1, 1).to(x.device)
                if self.opt["mix"]:
                    # tile so the noise can be concatenated with the image
                    zn = zn.repeat(1, 1, H, W)

        if self.opt["mix"]:
            inp = torch.cat([x, zn], 1) if zn is not None else x
        else:
            inp = zn  # guaranteed non-None by the __init__ check

        noise = self.deg_noise(inp)
        return noise
@ARCH_REGISTRY.register()
class DiscriminatorVGG128(nn.Module):
    """VGG-style discriminator for 128x128 inputs.

    Five stages of (3x3 conv, strided 4x4 conv) halve the spatial size down
    to 4x4, followed by two linear layers producing a single logit.
    Attribute names (convX_Y / bnX_Y) match the original layout so existing
    checkpoints still load.
    """

    def __init__(self, in_nc, nf):
        super().__init__()
        # (in_channels, out_channels) for each downsampling stage
        widths = [
            (in_nc, nf),
            (nf, nf * 2),
            (nf * 2, nf * 4),
            (nf * 4, nf * 8),
            (nf * 8, nf * 8),
        ]
        for i, (c_in, c_out) in enumerate(widths):
            if i == 0:
                # first 3x3 conv keeps its bias and has no batch norm
                self.conv0_0 = nn.Conv2d(c_in, c_out, 3, 1, 1, bias=True)
            else:
                setattr(self, f"conv{i}_0", nn.Conv2d(c_in, c_out, 3, 1, 1, bias=False))
                setattr(self, f"bn{i}_0", nn.BatchNorm2d(c_out, affine=True))
            # strided 4x4 conv halves the spatial resolution
            setattr(self, f"conv{i}_1", nn.Conv2d(c_out, c_out, 4, 2, 1, bias=False))
            setattr(self, f"bn{i}_1", nn.BatchNorm2d(c_out, affine=True))

        # classifier head assumes a 512 x 4 x 4 feature map (nf = 64)
        self.linear1 = nn.Linear(512 * 4 * 4, 100)
        self.linear2 = nn.Linear(100, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.lrelu(self.conv0_0(x))
        fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
        for i in range(1, 5):
            stage_conv = getattr(self, f"conv{i}_0")
            stage_bn = getattr(self, f"bn{i}_0")
            down_conv = getattr(self, f"conv{i}_1")
            down_bn = getattr(self, f"bn{i}_1")
            fea = self.lrelu(stage_bn(stage_conv(fea)))
            fea = self.lrelu(down_bn(down_conv(fea)))
        flat = fea.view(fea.size(0), -1)
        return self.linear2(self.lrelu(self.linear1(flat)))
@ARCH_REGISTRY.register()
class PatchGANDiscriminator(nn.Module):
    """PatchGAN discriminator producing a patch-wise prediction map."""

    def __init__(self, in_c, nf, nb, stride=1, norm_layer=nn.InstanceNorm2d):
        """Construct a PatchGAN discriminator.

        Parameters:
            in_c (int)   -- number of channels in input images
            nf (int)     -- number of filters in the first conv layer
            nb (int)     -- number of conv layers in the discriminator
            stride (int) -- stride of the intermediate conv layers
            norm_layer   -- normalization layer
        """
        super().__init__()
        # BatchNorm2d has affine parameters, so its convs need no bias;
        # InstanceNorm2d does not, so the bias is kept
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        ksize, pad = 3, 1
        layers = [
            nn.Conv2d(in_c, nf, kernel_size=ksize, stride=1, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        # gradually widen the feature maps (capped at 8x nf)
        mult = 1
        for n in range(1, nb):
            prev, mult = mult, min(2 ** n, 8)
            layers += [
                nn.Conv2d(
                    nf * prev,
                    nf * mult,
                    kernel_size=ksize,
                    stride=stride,
                    padding=pad,
                    bias=use_bias,
                ),
                norm_layer(nf * mult),
                nn.LeakyReLU(0.2, True),
            ]

        prev, mult = mult, min(2 ** nb, 8)
        layers += [
            nn.Conv2d(
                nf * prev,
                nf * mult,
                kernel_size=ksize,
                stride=1,
                padding=pad,
                bias=use_bias,
            ),
            norm_layer(nf * mult),
            nn.LeakyReLU(0.2, True),
        ]

        # final conv maps back to nf channels: the patch prediction map
        layers.append(nn.Conv2d(nf * mult, nf, kernel_size=ksize, stride=1, padding=pad))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
class ResBlock(nn.Module):
    """EDSR residual block: conv -> act -> conv, scaled and added to the input."""

    def __init__(
        self,
        conv,
        n_feat,
        kernel_size,
        bias=True,
        bn=False,
        act=nn.ReLU(True),
        res_scale=1,
    ):
        super().__init__()
        layers = []
        for i in range(2):
            layers.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                layers.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                # activation only between the two convolutions
                layers.append(act)
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        # residual branch is damped by res_scale before the skip connection
        return x + self.body(x) * self.res_scale
@ARCH_REGISTRY.register()
class EDSR(nn.Module):
    """EDSR super-resolution network: head conv -> residual body -> upsampling tail.

    Inputs are multiplied by 255 before the mean-shift layers and divided
    back afterwards, so callers work in the [0, 1] range.
    """

    def __init__(self, nb, nf, res_scale=0.1, upscale=4, conv=default_conv):
        super().__init__()
        ksize = 3
        act = nn.ReLU(True)

        # subtract / restore the fixed RGB channel means (see MeanShift)
        self.sub_mean = MeanShift(255.0, sign=-1)
        self.add_mean = MeanShift(255.0, sign=1)

        self.head = nn.Sequential(conv(3, nf, ksize))

        body = [
            ResBlock(conv, nf, ksize, act=act, res_scale=res_scale)
            for _ in range(nb)
        ]
        body.append(conv(nf, nf, ksize))
        self.body = nn.Sequential(*body)

        self.tail = nn.Sequential(
            Upsampler(conv, upscale, nf, act=False),
            conv(nf, 3, ksize),
        )

    def forward(self, x):
        feat = self.sub_mean(x * 255.0)
        feat = self.head(feat)
        out = self.body(feat) + feat  # global residual connection
        out = self.tail(out)
        return self.add_mean(out) / 255.0
@LOSS_REGISTRY.register()
class GANLoss(nn.Module):
    """GAN objective supporting several adversarial formulations.

    Args:
        gan_type (str): One of 'vanilla', 'lsgan', 'wgan', 'wgan_softplus'
            or 'hinge'.
        real_label_val (float): Target value for real samples. Default: 1.0.
        fake_label_val (float): Target value for fake samples. Default: 0.0.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
        super().__init__()
        self.gan_type = gan_type
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # criterion modules for the types backed by a torch loss
        builders = {
            "vanilla": nn.BCEWithLogitsLoss,
            "lsgan": nn.MSELoss,
            "hinge": nn.ReLU,
        }
        if self.gan_type in builders:
            self.loss = builders[self.gan_type]()
        elif self.gan_type == "wgan":
            self.loss = self._wgan_loss
        elif self.gan_type == "wgan_softplus":
            self.loss = self._wgan_softplus_loss
        else:
            raise NotImplementedError(f"GAN type {self.gan_type} is not implemented.")

    def _wgan_loss(self, input, target):
        """WGAN critic loss: maximize the score on real, minimize on fake.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        mean = input.mean()
        return -mean if target else mean

    def _wgan_softplus_loss(self, input, target):
        """Softplus variant of the WGAN loss (StyleGAN2's logistic loss).

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        if target:
            return F.softplus(-input).mean()
        return F.softplus(input).mean()

    def get_target_label(self, input, target_is_real):
        """Return the target for ``input``.

        For the WGAN variants this is the boolean ``target_is_real``
        itself; otherwise a tensor filled with the real/fake label value.
        """
        if self.gan_type in ["wgan", "wgan_softplus"]:
            return target_is_real
        fill = self.real_label_val if target_is_real else self.fake_label_val
        return input.new_ones(input.size()) * fill

    def forward(self, input, target_is_real, is_disc=False):
        """Compute the adversarial loss for ``input`` (network prediction).

        ``is_disc`` selects the discriminator form of the hinge loss;
        generators use the non-saturating form instead.
        """
        if self.gan_type == "hinge":
            if is_disc:  # discriminator side of hinge-gan
                signed = -input if target_is_real else input
                return self.loss(1 + signed).mean()
            return -input.mean()  # generator side of hinge-gan
        target_label = self.get_target_label(input, target_is_real)
        return self.loss(input, target_label)
""" def __init__( self, layer_weights, vgg_type="vgg19", use_input_norm=True, range_norm=False, perceptual_weight=1.0, style_weight=0.0, criterion="l1", ): super(PerceptualLoss, self).__init__() self.perceptual_weight = perceptual_weight self.style_weight = style_weight self.layer_weights = layer_weights self.vgg = VGGFeatureExtractor( layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, use_input_norm=use_input_norm, range_norm=range_norm, ) self.criterion_type = criterion if self.criterion_type == "l1": self.criterion = torch.nn.L1Loss() elif self.criterion_type == "l2": self.criterion = torch.nn.L2loss() elif self.criterion_type == "fro": self.criterion = None else: raise NotImplementedError(f"{criterion} criterion has not been supported.") def forward(self, x, gt): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). gt (Tensor): Ground-truth tensor with shape (n, c, h, w). Returns: Tensor: Forward results. """ # extract vgg features x_features = self.vgg(x) gt_features = self.vgg(gt.detach()) # calculate perceptual loss if self.perceptual_weight > 0: percep_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": percep_loss += ( torch.norm(x_features[k] - gt_features[k], p="fro") * self.layer_weights[k] ) else: percep_loss += ( self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] ) percep_loss *= self.perceptual_weight else: percep_loss = None # calculate style loss if self.style_weight > 0: style_loss = 0 for k in x_features.keys(): if self.criterion_type == "fro": style_loss += ( torch.norm( self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p="fro", ) * self.layer_weights[k] ) else: style_loss += ( self.criterion( self._gram_mat(x_features[k]), self._gram_mat(gt_features[k]), ) * self.layer_weights[k] ) style_loss *= self.style_weight else: style_loss = None return percep_loss, style_loss def _gram_mat(self, x): """Calculate Gram matrix. 
        Args:
            x (torch.Tensor): Tensor with shape of (n, c, h, w).

        Returns:
            torch.Tensor: Gram matrix.
        """
        n, c, h, w = x.size()
        # Flatten spatial dims, then G = F F^T normalised by c*h*w.
        features = x.view(n, c, w * h)
        features_t = features.transpose(1, 2)
        gram = features.bmm(features_t) / (c * h * w)
        return gram


@LOSS_REGISTRY.register()
class CharbonnierLoss(nn.Module):
    """Charbonnier Loss (L1)"""

    # NOTE(review): eps is added un-squared inside the sqrt, so eps=1e-6
    # matches the common formulation written with eps=1e-3 squared.
    def __init__(self, eps=1e-6):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        # Smooth L1-like penalty: mean(sqrt(diff^2 + eps)).
        diff = x - y
        loss = torch.mean(torch.sqrt(diff * diff + self.eps))
        return loss


# WGAN-GP gradient penalty.  NOTE(review): unlike the losses above, this
# class is not registered in LOSS_REGISTRY -- confirm that is intentional.
class GradientPenaltyLoss(nn.Module):
    def __init__(self, device=torch.device("cpu")):
        super(GradientPenaltyLoss, self).__init__()
        # Reusable all-ones tensor used as autograd.grad's grad_outputs.
        self.register_buffer("grad_outputs", torch.Tensor())
        self.grad_outputs = self.grad_outputs.to(device)

    def get_grad_outputs(self, input):
        # Lazily resize the cached ones-tensor to match the critic output.
        if self.grad_outputs.size() != input.size():
            self.grad_outputs.resize_(input.size()).fill_(1.0)
        return self.grad_outputs

    def forward(self, interp, interp_crit):
        """Mean of (||d interp_crit / d interp||_2 - 1)^2 over the batch,
        for interpolated samples `interp` with critic scores `interp_crit`."""
        grad_outputs = self.get_grad_outputs(interp_crit)
        grad_interp = torch.autograd.grad(
            outputs=interp_crit,
            inputs=interp,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        grad_interp = grad_interp.view(grad_interp.size(0), -1)
        grad_interp_norm = grad_interp.norm(2, dim=1)
        loss = ((grad_interp_norm - 1) ** 2).mean()
        return loss


================================================ FILE: codes/config/PDM-SR/archs/lr_scheduler.py ================================================
import math
from collections import Counter, defaultdict

import torch
from torch.optim.lr_scheduler import _LRScheduler

from utils.registry import LR_SCHEDULER_REGISTRY


@LR_SCHEDULER_REGISTRY.register()
class LinearDecayLR(_LRScheduler):
    """Linearly decay each group's lr by decay_prop of its initial_lr
    over total_steps scheduler steps."""

    def __init__(
        self,
        optimizer,
        decay_prop,
        total_steps,
        last_epoch=-1,
    ):
        self.decay_prop = decay_prop
        self.total_steps = total_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # lr(t) = initial_lr * (1 - (t + 1) * decay_prop / total_steps)
        return [
            group["initial_lr"]
            * (1 - (self.last_epoch + 1) * self.decay_prop / self.total_steps)
            for
group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class MultiStepRestartLR(_LRScheduler): def __init__( self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, clear_state=False, last_epoch=-1, ): self.milestones = Counter(milestones) self.gamma = gamma self.clear_state = clear_state self.restarts = restarts if restarts else [0] self.restart_weights = weights if weights else [1] assert len(self.restarts) == len( self.restart_weights ), "restarts and their weights do not match." super().__init__(optimizer, last_epoch) def get_lr(self): if self.last_epoch in self.restarts: if self.clear_state: self.optimizer.state = defaultdict(dict) weight = self.restart_weights[self.restarts.index(self.last_epoch)] return [ group["initial_lr"] * weight for group in self.optimizer.param_groups ] if self.last_epoch not in self.milestones: return [group["lr"] for group in self.optimizer.param_groups] return [ group["lr"] * self.gamma ** self.milestones[self.last_epoch] for group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class CosineAnnealingRestartLR(_LRScheduler): def __init__( self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1 ): self.T_period = T_period self.T_max = self.T_period[0] # current T period self.eta_min = eta_min self.restarts = restarts if restarts else [0] self.restart_weights = weights if weights else [1] self.last_restart = 0 assert len(self.restarts) == len( self.restart_weights ), "restarts and their weights do not match." 
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch == 0:
            # First call: start from the base lrs.
            return self.base_lrs
        elif self.last_epoch in self.restarts:
            # Restart point: switch to the next cosine period and reset lrs
            # to initial_lr scaled by this restart's weight.
            self.last_restart = self.last_epoch
            # NOTE(review): restarts[i] selects T_period[i + 1], so T_period
            # must hold one more entry than restarts -- confirm configs.
            self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
            weight = self.restart_weights[self.restarts.index(self.last_epoch)]
            return [
                group["initial_lr"] * weight for group in self.optimizer.param_groups
            ]
        elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (
            2 * self.T_max
        ) == 0:
            # Wrap-around point of the cosine cycle: step the lr back up.
            return [
                group["lr"]
                + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # Chainable cosine-annealing recurrence (same form as
        # torch.optim.lr_scheduler.CosineAnnealingLR), measured from the
        # last restart.
        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max))
            / (
                1
                + math.cos(
                    math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max
                )
            )
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]


================================================ FILE: codes/config/PDM-SR/archs/module_util.py ================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


def initialize_weights(net_l, scale=1):
    # Kaiming-init every Conv2d/Linear weight (scaled, e.g. 0.1 for residual
    # branches) and zero the biases; BatchNorm layers start as identity.
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode="fan_in")
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode="fan_in")
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)


def make_layer(block, n_layers):
    # Stack n_layers fresh instances produced by the `block` factory.
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)


class ResidualBlock_noBN(nn.Module):
    """Residual block w/o BN
    ---Conv-ReLU-Conv-+-
     |________________|
    """

    def __init__(self, nf=64):
        super(ResidualBlock_noBN, self).__init__()
self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) # initialization initialize_weights([self.conv1, self.conv2], 0.1) def forward(self, x): identity = x out = F.relu(self.conv1(x), inplace=True) out = self.conv2(out) return identity + out def flow_warp(x, flow, interp_mode="bilinear", padding_mode="zeros"): """Warp an image or feature map with optical flow Args: x (Tensor): size (N, C, H, W) flow (Tensor): size (N, H, W, 2), normal value interp_mode (str): 'nearest' or 'bilinear' padding_mode (str): 'zeros' or 'border' or 'reflection' Returns: Tensor: warped image or feature map """ assert x.size()[-2:] == flow.size()[1:3] B, C, H, W = x.size() # mesh grid grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 grid.requires_grad = False grid = grid.type_as(x) vgrid = grid + flow # scale grid to [-1,1] vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) return output ================================================ FILE: codes/config/PDM-SR/archs/rcan.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F from torch.autograd import Variable from utils.registry import ARCH_REGISTRY def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias ) class MeanShift(nn.Conv2d): def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): super(MeanShift, self).__init__(3, 3, kernel_size=1) std = torch.Tensor(rgb_std) self.weight.data = torch.eye(3).view(3, 3, 1, 1) self.weight.data.div_(std.view(3, 1, 1, 1)) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 
self.bias.data.div_(std) self.requires_grad = False class BasicBlock(nn.Sequential): def __init__( self, in_channels, out_channels, kernel_size, stride=1, bias=False, bn=True, act=nn.ReLU(True), ): m = [ nn.Conv2d( in_channels, out_channels, kernel_size, padding=(kernel_size // 2), stride=stride, bias=bias, ) ] if bn: m.append(nn.BatchNorm2d(out_channels)) if act is not None: m.append(act) super(BasicBlock, self).__init__(*m) class ResBlock(nn.Module): def __init__( self, conv, n_feat, kernel_size, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(ResBlock, self).__init__() m = [] for i in range(2): m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: m.append(nn.BatchNorm2d(n_feat)) if i == 0: m.append(act) self.body = nn.Sequential(*m) self.res_scale = res_scale def forward(self, x): res = self.body(x).mul(self.res_scale) res += x return res class Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
for _ in range(int(math.log(scale, 2))): m.append(conv(n_feat, 4 * n_feat, 3, bias)) m.append(nn.PixelShuffle(2)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) elif scale == 3: m.append(conv(n_feat, 9 * n_feat, 3, bias)) m.append(nn.PixelShuffle(3)) if bn: m.append(nn.BatchNorm2d(n_feat)) if act: m.append(act()) else: raise NotImplementedError super(Upsampler, self).__init__(*m) def make_model(args, parent=False): return RCAN(args) ## Channel Attention (CA) Layer class CALayer(nn.Module): def __init__(self, channel, reduction=16): super(CALayer, self).__init__() # global average pooling: feature --> point self.avg_pool = nn.AdaptiveAvgPool2d(1) # feature channel downscale and upscale --> channel weight self.conv_du = nn.Sequential( nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), nn.ReLU(inplace=True), nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), nn.Sigmoid(), ) def forward(self, x): y = self.avg_pool(x) y = self.conv_du(y) return x * y ## Residual Channel Attention Block (RCAB) class RCAB(nn.Module): def __init__( self, conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ): super(RCAB, self).__init__() modules_body = [] for i in range(2): modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) if bn: modules_body.append(nn.BatchNorm2d(n_feat)) if i == 0: modules_body.append(act) modules_body.append(CALayer(n_feat, reduction)) self.body = nn.Sequential(*modules_body) self.res_scale = res_scale def forward(self, x): res = self.body(x) # res = self.body(x).mul(self.res_scale) res += x return res ## Residual Group (RG) class ResidualGroup(nn.Module): def __init__( self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks ): super(ResidualGroup, self).__init__() modules_body = [] modules_body = [ RCAB( conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1, ) for _ in range(n_resblocks) ] 
modules_body.append(conv(n_feat, n_feat, kernel_size)) self.body = nn.Sequential(*modules_body) def forward(self, x): res = self.body(x) res += x return res ## Residual Channel Attention Network (RCAN) @ARCH_REGISTRY.register() class RCAN(nn.Module): def __init__(self, ng, nb, nf, reduction=16, upscale=4, conv=default_conv): super(RCAN, self).__init__() n_resgroups = ng n_resblocks = nb n_feats = nf kernel_size = 3 reduction = reduction scale = upscale act = nn.ReLU(True) # RGB mean for DIV2K rgb_mean = (0.4488, 0.4371, 0.4040) rgb_std = (1.0, 1.0, 1.0) self.sub_mean = MeanShift(1.0, rgb_mean, rgb_std, -1) # define head module modules_head = [conv(3, n_feats, kernel_size)] # define body module modules_body = [ ResidualGroup( conv, n_feats, kernel_size, reduction, act=act, res_scale=1.0, n_resblocks=nb, ) for _ in range(ng) ] modules_body.append(conv(n_feats, n_feats, kernel_size)) # define tail module modules_tail = [ Upsampler(conv, scale, n_feats, act=False), conv(n_feats, 3, kernel_size), ] self.add_mean = MeanShift(1.0, rgb_mean, rgb_std, 1) self.head = nn.Sequential(*modules_head) self.body = nn.Sequential(*modules_body) self.tail = nn.Sequential(*modules_tail) def forward(self, x): x = self.sub_mean(x) x = self.head(x) res = self.body(x) res += x x = self.tail(res) x = self.add_mean(x) return x def load_state_dict(self, state_dict, strict=False): own_state = self.state_dict() for name, param in state_dict.items(): if name in own_state: if isinstance(param, nn.Parameter): param = param.data try: own_state[name].copy_(param) except Exception: if name.find("tail") >= 0: print("Replace pre-trained upsampler to new one...") else: raise RuntimeError( "While copying the parameter named {}, " "whose dimensions in the model are {} and " "whose dimensions in the checkpoint are {}.".format( name, own_state[name].size(), param.size() ) ) elif strict: if name.find("tail") == -1: raise KeyError('unexpected key "{}" in state_dict'.format(name)) if strict: missing = 
set(own_state.keys()) - set(state_dict.keys()) if len(missing) > 0: raise KeyError('missing keys in state_dict: "{}"'.format(missing)) ================================================ FILE: codes/config/PDM-SR/archs/rrdb.py ================================================ import functools from utils.registry import ARCH_REGISTRY from .module_util import * class ResidualDenseBlock_5C(nn.Module): def __init__(self, nf=64, gc=32, bias=True): super(ResidualDenseBlock_5C, self).__init__() # gc: growth channel, i.e. intermediate channels self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) # initialization initialize_weights( [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1 ) def forward(self, x): x1 = self.lrelu(self.conv1(x)) x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) return x5 * 0.2 + x class RRDB(nn.Module): """Residual in Residual Dense Block""" def __init__(self, nf, gc=32): super(RRDB, self).__init__() self.rdb1 = ResidualDenseBlock_5C(nf, gc) self.rdb2 = ResidualDenseBlock_5C(nf, gc) self.rdb3 = ResidualDenseBlock_5C(nf, gc) def forward(self, x): out = self.rdb1(x) out = self.rdb2(out) out = self.rdb3(out) return out * 0.2 + x @ARCH_REGISTRY.register() class RRDBNet(nn.Module): def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4): super(RRDBNet, self).__init__() self.upscale = upscale RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) self.body = make_layer(RRDB_block_f, nb) self.conv_body = nn.Conv2d(nf, nf, 3, 1, 1, 
bias=True) #### upsampling self.conv_up1 = nn.Conv2d(nf, nf, 3, 1, 1) if upscale == 4: self.conv_up2 = nn.Conv2d(nf, nf, 3, 1, 1) self.conv_hr = nn.Conv2d(nf, nf, 3, 1, 1) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): fea = self.conv_first(x) trunk = self.conv_body(self.body(fea)) fea = fea + trunk if self.upscale == 2 or self.upscale == 3: fea = self.lrelu( self.conv_up1( F.interpolate(fea, scale_factor=self.upscale, mode="nearest") ) ) if self.upscale == 4: fea = self.lrelu( self.conv_up1(F.interpolate(fea, scale_factor=2, mode="nearest")) ) fea = self.lrelu( self.conv_up2(F.interpolate(fea, scale_factor=2, mode="nearest")) ) out = self.conv_last(self.lrelu(self.conv_hr(fea))) return out ================================================ FILE: codes/config/PDM-SR/archs/srresnet.py ================================================ import functools from utils.registry import ARCH_REGISTRY from .module_util import * @ARCH_REGISTRY.register() class MSRResNet(nn.Module): """modified SRResNet""" def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4): super(MSRResNet, self).__init__() self.upscale = upscale self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) basic_block = functools.partial(ResidualBlock_noBN, nf=nf) self.recon_trunk = make_layer(basic_block, nb) # upsampling if self.upscale == 2: self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(2) elif self.upscale == 3: self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(3) elif self.upscale == 4: self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) self.pixel_shuffle = nn.PixelShuffle(2) self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) # activation function self.lrelu = nn.LeakyReLU(negative_slope=0.1, 
inplace=True) # initialization initialize_weights( [self.conv_first, self.upconv1, self.HRconv, self.conv_last], 0.1 ) if self.upscale == 4: initialize_weights(self.upconv2, 0.1) def forward(self, x): fea = self.lrelu(self.conv_first(x)) out = self.recon_trunk(fea) if self.upscale == 4: out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) elif self.upscale == 3 or self.upscale == 2: out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) out = self.conv_last(self.lrelu(self.HRconv(out))) base = F.interpolate( x, scale_factor=self.upscale, mode="bilinear", align_corners=False ) out += base return out ================================================ FILE: codes/config/PDM-SR/archs/vgg.py ================================================ import os from collections import OrderedDict import torch from torch import nn as nn from torchvision.models import vgg as vgg from utils.registry import ARCH_REGISTRY VGG_PRETRAIN_PATH = "checkpoints/pretrained_models/vgg19-dcbb9e9d.pth" NAMES = { "vgg11": [ "conv1_1", "relu1_1", "pool1", "conv2_1", "relu2_1", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "pool5", ], "vgg13": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "pool5", ], "vgg16": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "pool5", ], "vgg19": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", 
"conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "conv3_4", "relu3_4", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "conv4_4", "relu4_4", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "conv5_4", "relu5_4", "pool5", ], } def insert_bn(names): """Insert bn layer after each conv. Args: names (list): The list of layer names. Returns: list: The list of layer names with bn layers. """ names_bn = [] for name in names: names_bn.append(name) if "conv" in name: position = name.replace("conv", "") names_bn.append("bn" + position) return names_bn @ARCH_REGISTRY.register() class VGGFeatureExtractor(nn.Module): """VGG network for feature extraction. In this implementation, we allow users to choose whether use normalization in the input feature and the type of vgg network. Note that the pretrained path must fit the vgg type. Args: layer_name_list (list[str]): Forward function returns the corresponding features according to the layer_name_list. Example: {'relu1_1', 'relu2_1', 'relu3_1'}. vgg_type (str): Set the type of vgg network. Default: 'vgg19'. use_input_norm (bool): If True, normalize the input image. Importantly, the input feature must in the range [0, 1]. Default: True. range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. Default: False. requires_grad (bool): If true, the parameters of VGG network will be optimized. Default: False. remove_pooling (bool): If true, the max pooling operations in VGG net will be removed. Default: False. pooling_stride (int): The stride of max pooling operation. Default: 2. 
""" def __init__( self, layer_name_list, vgg_type="vgg19", use_input_norm=True, range_norm=False, requires_grad=False, remove_pooling=False, pooling_stride=2, ): super(VGGFeatureExtractor, self).__init__() self.layer_name_list = layer_name_list self.use_input_norm = use_input_norm self.range_norm = range_norm self.names = NAMES[vgg_type.replace("_bn", "")] if "bn" in vgg_type: self.names = insert_bn(self.names) # only borrow layers that will be used to avoid unused params max_idx = 0 for v in layer_name_list: idx = self.names.index(v) if idx > max_idx: max_idx = idx if os.path.exists(VGG_PRETRAIN_PATH): vgg_net = getattr(vgg, vgg_type)(pretrained=False) state_dict = torch.load( VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage ) vgg_net.load_state_dict(state_dict) else: vgg_net = getattr(vgg, vgg_type)(pretrained=True) features = vgg_net.features[: max_idx + 1] modified_net = OrderedDict() for k, v in zip(self.names, features): if "pool" in k: # if remove_pooling is true, pooling operation will be removed if remove_pooling: continue else: # in some cases, we may want to change the default stride modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) else: modified_net[k] = v self.vgg_net = nn.Sequential(modified_net) if not requires_grad: self.vgg_net.eval() for param in self.parameters(): param.requires_grad = False else: self.vgg_net.train() for param in self.parameters(): param.requires_grad = True if self.use_input_norm: # the mean is for image with range [0, 1] self.register_buffer( "mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) ) # the std is for image with range [0, 1] self.register_buffer( "std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) ) def forward(self, x): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). Returns: Tensor: Forward results. 
""" if self.range_norm: x = (x + 1) / 2 if self.use_input_norm: x = (x - self.mean) / self.std output = {} for key, layer in self.vgg_net._modules.items(): x = layer(x) if key in self.layer_name_list: output[key] = x.clone() return output ================================================ FILE: codes/config/PDM-SR/count_flops.py ================================================ import argparse import sys import torch from torchsummaryX import summary sys.path.append("../../") import utils.option as option from models import create_model parser = argparse.ArgumentParser() parser.add_argument( "--opt", type=str, default="options/setting1/test/test_setting1_x4.yml", help="Path to option YMAL file of Predictor.", ) args = parser.parse_args() opt = option.parse(args.opt, root_path=".", is_train=True) opt = option.dict_to_nonedict(opt) model = create_model(opt) test_tensor = torch.randn(1, 3, 270, 180).cuda() for name, net in model.networks.items(): summary(net.cuda(), x=test_tensor) print("Above are results for net {}".format(name)) input() ================================================ FILE: codes/config/PDM-SR/inference.py ================================================ import argparse import logging import math import os import os.path as osp import random import sys import cv2 from collections import defaultdict from glob import glob from tqdm import tqdm import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp from tensorboardX import SummaryWriter sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from data.data_sampler import DistIterSampler from metrics import IQA from models import create_model #### options parser = argparse.ArgumentParser() parser.add_argument( "-opt", type=str, default="options/test/2020Track2.yml", help="Path to options YMAL file.", ) parser.add_argument("-input_dir", type=str, default="../../../data_samples/LR") 
parser.add_argument("-output_dir", type=str, default="../../../data_samples/PDM-SR/") args = parser.parse_args() opt = option.parse(args.opt, is_train=False) opt = option.dict_to_nonedict(opt) model = create_model(opt) if not osp.exists(args.output_dir): os.makedirs(args.output_dir) test_files = glob(osp.join(args.input_dir, "*")) for inx, path in tqdm(enumerate(test_files)): name = path.split("/")[-1].split(".")[0] img = cv2.imread(path)[:, :, [2, 1, 0]] img = img.transpose(2, 0, 1)[None] / 255 img_t = torch.as_tensor(np.ascontiguousarray(img)).float() model.test({"src": img_t}) outdict = model.get_current_visuals() sr = outdict["sr"] sr_im = util.tensor2img(sr) save_path = osp.join(args.output_dir, "{}_x{}.png".format(name, opt["scale"])) cv2.imwrite(save_path, sr_im) ================================================ FILE: codes/config/PDM-SR/models/__init__.py ================================================ import importlib import logging import os import os.path as osp from utils.registry import MODEL_REGISTRY logger = logging.getLogger("base") model_folder = osp.dirname(__file__) model_names = [ osp.splitext(osp.basename(v))[0] for v in os.listdir(model_folder) if v.endswith("_model.py") ] _model_modules = [ importlib.import_module(f"models.{file_name}") for file_name in model_names ] def create_model(opt, **kwarg): model = opt["model"] m = MODEL_REGISTRY.get(model)(opt, **kwarg) logger.info("Model [{:s}] is created.".format(m.__class__.__name__)) return m ================================================ FILE: codes/config/PDM-SR/models/base_model.py ================================================ import logging import os from collections import OrderedDict import torch import torch.nn as nn from torch.nn.parallel import DataParallel, DistributedDataParallel from archs import build_loss, build_network, build_scheduler from utils.registry import MODEL_REGISTRY logger = logging.getLogger("base") @MODEL_REGISTRY.register() class BaseModel: def __init__(self, 
opt): self.opt = opt if opt["dist"]: self.rank = torch.distributed.get_rank() self.world_size = torch.distributed.get_world_size() else: self.rank = 0 # non dist training self.device = torch.device("cuda" if opt["gpu_ids"] is not None else "cpu") self.is_train = opt["is_train"] self.log_dict = OrderedDict() self.data_names = [] self.networks = {} self.optimizers = {} self.schedulers = {} def setup_train(self, train_opt): # define losses loss_opt = train_opt["losses"] self.losses = self.build_losses(loss_opt) # build optmizers optimizer_opts = train_opt["optimizers"] self.optimizers = self.build_optimizers(optimizer_opts) # set schedulers scheduler_opts = train_opt["schedulers"] self.schedulers = self.build_schedulers(scheduler_opts) # set to training state self.set_network_state(self.networks.keys(), "train") def feed_data(self, data): pass def optimize_parameters(self): pass def get_current_visuals(self): pass def get_current_losses(self): pass def print_network(self): pass def save(self, label): pass def load(self): pass def build_network(self, net_opt): net = build_network(net_opt) if isinstance(net, nn.Module): net = self.model_to_device(net) if net_opt.get("pretrain"): pretrain = net_opt.pop("pretrain") self.load_network(net, pretrain["path"], pretrain["strict_load"]) self.print_network(net) return net def build_losses(self, loss_opt): losses = {} defined_loss_names = list(loss_opt.keys()) assert set(defined_loss_names).issubset(set(self.loss_names)) for name in defined_loss_names: loss_conf = loss_opt.get(name) if loss_conf["weight"] > 0: self.loss_weights[name] = loss_conf.pop("weight") losses[name] = build_loss(loss_conf).to(self.device) return losses def build_optimizers(self, optim_opts): optimizers = {} if "default" in optim_opts.keys(): default_optim = optim_opts.pop("default") defined_optimizer_names = list(optim_opts.keys()) assert set(defined_optimizer_names).issubset(self.networks.keys()) for name in defined_optimizer_names: optim_opt = 
                        optim_opts[name]
            if optim_opt is None:
                optim_opt = default_optim.copy()

            # only optimize parameters that actually require gradients
            params = []
            for v in self.networks[name].parameters():
                if v.requires_grad:
                    params.append(v)

            optim_type = optim_opt.pop("type")
            # look the optimizer class up by name on torch.optim (e.g. "Adam")
            optimizer = getattr(torch.optim, optim_type)(params=params, **optim_opt)
            optimizers[name] = optimizer
        return optimizers

    def build_schedulers(self, scheduler_opts):
        """Set up scheduler.

        One scheduler per optimizer; a "default" entry supplies settings for
        optimizer names whose own entry is null.
        """
        schedulers = {}
        if "default" in scheduler_opts.keys():
            default_opt = scheduler_opts.pop("default")

        for name in self.optimizers.keys():
            scheduler_opt = scheduler_opts[name]
            if scheduler_opt is None:
                scheduler_opt = default_opt.copy()
            schedulers[name] = build_scheduler(self.optimizers[name], scheduler_opt)
        return schedulers

    def model_to_device(self, net):
        """Model to device. It also warps models with DistributedDataParallel
        or DataParallel.

        Args:
            net (nn.Module)
        """
        net = net.to(self.device)
        if self.opt["dist"]:
            net = DistributedDataParallel(net, device_ids=[torch.cuda.current_device()])
        else:
            net = DataParallel(net)
        return net

    def print_network(self, net):
        # Generator
        s, n = self.get_network_description(net)
        if isinstance(net, nn.DataParallel) or isinstance(net, DistributedDataParallel):
            net_struc_str = "{} - {}".format(
                net.__class__.__name__, net.module.__class__.__name__
            )
        else:
            net_struc_str = "{}".format(net.__class__.__name__)
        # only rank 0 (or single-process runs) logs the structure
        if self.rank <= 0:
            logger.info(
                "Network G structure: {}, with parameters: {:,d}".format(
                    net_struc_str, n
                )
            )
            logger.info(s)

    def set_optimizer(self, names, operation):
        """Call a method by name ("zero_grad" / "step") on the named optimizers."""
        for name in names:
            getattr(self.optimizers[name], operation)()

    def set_requires_grad(self, names, requires_grad):
        """Enable/disable gradients for every parameter of the named networks."""
        for name in names:
            if isinstance(self.networks[name], nn.Module):
                for v in self.networks[name].parameters():
                    v.requires_grad = requires_grad

    def set_network_state(self, names, state):
        """Call .train() or .eval() (the `state` string) on the named networks."""
        for name in names:
            if isinstance(self.networks[name], nn.Module):
                getattr(self.networks[name], state)()

    def clip_grad_norm(self, names, norm):
        for name in names:
nn.utils.clip_grad_norm_(self.networks[name].parameters(), max_norm=norm) def _set_lr(self, lr_groups_l): """set learning rate for warmup, lr_groups_l: list for lr_groups. each for a optimizer""" for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): for param_group, lr in zip(optimizer.param_groups, lr_groups): param_group["lr"] = lr def _get_init_lr(self): # get the initial lr, which is set by the scheduler init_lr_groups_l = [] for optimizer in self.optimizers: init_lr_groups_l.append([v["initial_lr"] for v in optimizer.param_groups]) return init_lr_groups_l def update_learning_rate(self, cur_iter, warmup_iter=-1): for _, scheduler in self.schedulers.items(): scheduler.step() #### set up warm up learning rate if cur_iter < warmup_iter: # get initial lr for each group init_lr_g_l = self._get_init_lr() # modify warming-up learning rates warm_up_lr_l = [] for init_lr_g in init_lr_g_l: warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) # set learning rate self._set_lr(warm_up_lr_l) def get_current_learning_rate(self): # return self.schedulers[0].get_lr()[0] return list(self.optimizers.values())[0].param_groups[0]["lr"] def get_network_description(self, network): """Get the string and total parameters of the network""" if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module s = str(network) n = sum(map(lambda x: x.numel(), network.parameters())) return s, n def save_network(self, network, network_label, iter_label): save_filename = "{}_{}.pth".format(iter_label, network_label) save_path = os.path.join(self.opt["path"]["models"], save_filename) if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module state_dict = network.state_dict() for key, param in state_dict.items(): state_dict[key] = param.cpu() torch.save(state_dict, save_path) def save(self, iter_label): for name in self.optimizers.keys(): 
            self.save_network(self.networks[name], name, iter_label)

    def load_network(self, network, load_path, strict=True):
        """Load a checkpoint into `network`, stripping any 'module.' prefix
        left over from (Distributed)DataParallel saves. No-op if path is None."""
        if load_path is not None:
            if isinstance(network, nn.DataParallel) or isinstance(
                network, DistributedDataParallel
            ):
                network = network.module
            load_net = torch.load(load_path)
            load_net_clean = OrderedDict()  # remove unnecessary 'module.'
            for k, v in load_net.items():
                if k.startswith("module."):
                    load_net_clean[k[7:]] = v
                else:
                    load_net_clean[k] = v
            network.load_state_dict(load_net_clean, strict=strict)

    def save_training_state(self, epoch, iter_step):
        """Saves training state during training, which will be used for resuming"""
        state = {"epoch": epoch, "iter": iter_step, "schedulers": {}, "optimizers": {}}
        for k, s in self.schedulers.items():
            state["schedulers"][k] = s.state_dict()
        for k, o in self.optimizers.items():
            state["optimizers"][k] = o.state_dict()
        save_filename = "{}.state".format(iter_step)
        save_path = os.path.join(self.opt["path"]["training_state"], save_filename)
        torch.save(state, save_path)

    def resume_training(self, resume_state):
        """Resume the optimizers and schedulers for training"""
        resume_optimizers = resume_state["optimizers"]
        resume_schedulers = resume_state["schedulers"]
        assert len(resume_optimizers) == len(
            self.optimizers
        ), "Wrong lengths of optimizers"
        assert len(resume_schedulers) == len(
            self.schedulers
        ), "Wrong lengths of schedulers"
        for name, o in resume_optimizers.items():
            self.optimizers[name].load_state_dict(o)
        for name, s in resume_schedulers.items():
            self.schedulers[name].load_state_dict(s)

    def reduce_loss_dict(self, loss_dict):
        """reduce loss dict.
        In distributed training, it averages the losses among different GPUs .

        Args:
            loss_dict (OrderedDict): Loss dict.
""" with torch.no_grad(): if self.opt["dist"]: keys = [] losses = [] for name, value in loss_dict.items(): keys.append(name) losses.append(value) losses = torch.stack(losses, 0) torch.distributed.reduce(losses, dst=0) if self.rank == 0: losses /= self.world_size loss_dict = {key: loss for key, loss in zip(keys, losses)} log_dict = OrderedDict() for name, value in loss_dict.items(): log_dict[name] = value.mean().item() return log_dict def get_current_log(self): return self.log_dict ================================================ FILE: codes/config/PDM-SR/models/deg_sr_model.py ================================================ import logging from collections import OrderedDict import random import torch import torch.nn as nn from kornia.color import rgb_to_grayscale from utils.registry import MODEL_REGISTRY from .base_model import BaseModel logger = logging.getLogger("base") class Quant(torch.autograd.Function): @staticmethod def forward(ctx, input): output = torch.clamp(input, 0, 1) output = (output * 255.).round() / 255. 
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: identity gradient for the rounding op
        return grad_output


class Quantization(nn.Module):
    """nn.Module wrapper around the Quant autograd function."""

    def __init__(self):
        super(Quantization, self).__init__()

    def forward(self, input):
        return Quant.apply(input)


@MODEL_REGISTRY.register()
class DegSRModel(BaseModel):
    """Joint degradation + SR model: netDeg learns HR->realistic-LR,
    netSR learns LR->HR; netD1/netD2 are the LR/HR discriminators."""

    def __init__(self, opt):
        super().__init__(opt)
        if opt["dist"]:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training

        self.data_names = ["src", "tgt"]

        self.network_names = ["netDeg", "netSR", "netD1", "netD2"]
        self.networks = {}

        self.loss_names = [
            "lr_adv",
            "sr_adv",
            "sr_pix_trans",
            "sr_pix_sr",
            "sr_percep",
            "lr_quant",
            "lr_gauss",
            "noise_mean",
            "color"
        ]
        self.loss_weights = {}
        self.losses = {}
        self.optimizers = {}

        # define networks and load pretrained models
        nets_opt = opt["networks"]
        defined_network_names = list(nets_opt.keys())
        assert set(defined_network_names).issubset(set(self.network_names))

        # each configured network becomes both an attribute and a registry entry
        for name in defined_network_names:
            setattr(self, name, self.build_network(nets_opt[name]))
            self.networks[name] = getattr(self, name)

        if self.is_train:
            train_opt = opt["train"]
            # setup loss, optimizers, schedulers
            self.setup_train(opt["train"])

            self.max_grad_norm = train_opt["max_grad_norm"]
            self.quant = Quantization()
            self.D_ratio = train_opt["D_ratio"]
            self.optim_sr = train_opt["optim_sr"]
            self.optim_deg = train_opt["optim_deg"]
            # NOTE(review): gray_dis is read unconditionally; train configs must
            # define it (possibly as null) — confirm against the option parser.
            self.gray_dis = train_opt["gray_dis"]

            ## buffer
            self.fake_lr_buffer = ShuffleBuffer(train_opt["buffer_size"])
            self.fake_hr_buffer = ShuffleBuffer(train_opt["buffer_size"])

    def feed_data(self, data):
        # "tgt" = HR images, "src" = real-world LR images (unpaired)
        self.syn_hr = data["tgt"].to(self.device)
        self.real_lr = data["src"].to(self.device)

    def deg_forward(self):
        """HR -> synthetic LR via netDeg; optionally quantize and run netSR
        (gradients flow back into netDeg through the SR branch)."""
        (
            self.fake_real_lr,
            self.predicted_kernel,
            self.predicted_noise,
        ) = self.netDeg(self.syn_hr)

        if self.losses.get("sr_pix_trans"):
            self.fake_real_lr_quant = self.quant(self.fake_real_lr)
            self.syn_sr = self.netSR(self.fake_real_lr_quant)

    def sr_forward(self):
        """Run netSR on the (detached) degraded LR; regenerates the LR image
        first when the degradation pass did not run this step."""
        if not self.optim_deg:
            (
                self.fake_real_lr,
                self.predicted_kernel,
                self.predicted_noise,
            )
            = self.netDeg(self.syn_hr)
        self.fake_real_lr_quant = self.quant(self.fake_real_lr)
        # detach: SR-phase gradients must not reach netDeg
        self.syn_sr = self.netSR(self.fake_real_lr_quant.detach())

    def optimize_trans_models(self, step, loss_dict):
        """One optimization step for the degradation (translation) branch:
        update netDeg against D1/pixel/noise losses, then update D1 every
        D_ratio steps. Returns the updated loss_dict."""
        self.set_requires_grad(["netDeg"], True)
        self.deg_forward()
        loss_G = 0

        if self.losses.get("lr_adv"):
            self.set_requires_grad(["netD1"], False)
            # optionally discriminate on luminance only
            if self.gray_dis:
                real = rgb_to_grayscale(self.real_lr)
                fake = rgb_to_grayscale(self.fake_real_lr)
            else:
                real = self.real_lr
                fake = self.fake_real_lr
            g1_adv_loss = self.calculate_gan_loss_G(
                self.netD1, self.losses["lr_adv"], real, fake
            )
            loss_dict["g1_adv"] = g1_adv_loss.item()
            loss_G += self.loss_weights["lr_adv"] * g1_adv_loss

        if self.losses.get("sr_pix_trans"):
            # SR reconstruction loss drives netDeg only; netSR is frozen here
            self.set_requires_grad(["netSR"], False)
            sr_pix = self.losses["sr_pix_trans"](self.syn_hr, self.syn_sr)
            loss_dict["sr_pix_trans"] = sr_pix.item()
            loss_G += self.loss_weights["sr_pix_trans"] * sr_pix

        if self.losses.get("noise_mean"):
            # regularize predicted noise toward zero mean
            noise = self.predicted_noise
            noise_mean = (
                self.losses["noise_mean"](noise, torch.zeros_like(noise))
            )
            loss_dict["noise_mean"] = noise_mean.item()
            loss_G += self.loss_weights["noise_mean"] * noise_mean

        self.set_optimizer(names=["netDeg"], operation="zero_grad")
        loss_G.backward()
        self.clip_grad_norm(["netDeg"], self.max_grad_norm)
        self.set_optimizer(names=["netDeg"], operation="step")

        ## update D
        if step % self.D_ratio == 0:
            self.set_requires_grad(["netD1"], True)
            if self.gray_dis:
                real = rgb_to_grayscale(self.real_lr)
                fake = rgb_to_grayscale(self.fake_real_lr)
            else:
                real = self.real_lr
                fake = self.fake_real_lr
            # fakes may come from the replay buffer to stabilize D1
            loss_d1 = self.calculate_gan_loss_D(
                self.netD1, self.losses["lr_adv"], real,
                self.fake_lr_buffer.choose(fake)
            )
            loss_dict["d1_adv"] = loss_d1.item()
            loss_D = self.loss_weights["lr_adv"] * loss_d1
            self.optimizers["netD1"].zero_grad()
            loss_D.backward()
            self.clip_grad_norm(["netD1"], self.max_grad_norm)
            self.optimizers["netD1"].step()

        return loss_dict

    def optimize_sr_models(self, step, loss_dict):
        self.set_requires_grad(["netSR"], True)
self.set_requires_grad(["netDeg"], False) self.sr_forward() loss_G = 0 if self.losses.get("sr_adv"): self.set_requires_grad(["netD2"], False) sr_adv_loss = self.calculate_gan_loss_G( self.netD2, self.losses["sr_adv"], self.syn_hr, self.syn_sr ) loss_dict["sr_adv"] = sr_adv_loss.item() loss_G += self.loss_weights["sr_adv"] * sr_adv_loss if self.losses.get("sr_percep"): sr_percep, sr_style = self.losses["sr_percep"]( self.syn_hr, self.syn_sr ) loss_dict["sr_percep"] = sr_percep.item() if sr_style is not None: loss_dict["sr_style"] = sr_style.item() loss_G += self.loss_weights["sr_percep"] * sr_style loss_G += self.loss_weights["sr_percep"] * sr_percep if self.losses.get("sr_pix_sr"): sr_pix = self.losses["sr_pix_sr"](self.syn_hr, self.syn_sr) loss_dict["sr_pix_sr"] = sr_pix.item() loss_G += self.loss_weights["sr_pix_sr"] * sr_pix self.set_optimizer(names=["netSR"], operation="zero_grad") loss_G.backward() self.clip_grad_norm(["netSR"], self.max_grad_norm) self.set_optimizer(names=["netSR"], operation="step") ## update D2 if step % self.D_ratio == 0: if self.losses.get("sr_adv"): self.set_requires_grad(["netD2"], True) loss_d2 = self.calculate_gan_loss_D( self.netD2, self.losses["sr_adv"], self.syn_hr, self.fake_hr_buffer.choose(self.syn_sr) ) loss_dict["d2_adv"] = loss_d2.item() loss_D = self.loss_weights["sr_adv"] * loss_d2 self.optimizers["netD2"].zero_grad() loss_D.backward() self.clip_grad_norm(["netD2"], self.max_grad_norm) self.optimizers["netD2"].step() return loss_dict def optimize_parameters(self, step): loss_dict = OrderedDict() # optimize trans if self.optim_deg: loss_dict = self.optimize_trans_models(step, loss_dict) # optimize SR if self.optim_sr: loss_dict = self.optimize_sr_models(step, loss_dict) self.log_dict = loss_dict def calculate_gan_loss_D(self, netD, criterion, real, fake): d_pred_fake = netD(fake.detach()) d_pred_real = netD(real) loss_real = criterion(d_pred_real, True, is_disc=True) loss_fake = criterion(d_pred_fake, False, is_disc=True) 
        return (loss_real + loss_fake) / 2

    def calculate_gan_loss_G(self, netD, criterion, real, fake):
        """Generator-side adversarial loss: fake scored as real.
        (`real` is unused here; kept for signature symmetry with the D loss.)"""
        d_pred_fake = netD(fake)
        loss_real = criterion(d_pred_fake, True, is_disc=False)
        return loss_real

    def test(self, test_data, crop_size=None):
        """Inference: SR the "src" image (optionally tile-wise via crop_test);
        if a target is given and netDeg exists, also produce its degraded LR."""
        self.src = test_data["src"].to(self.device)
        if test_data.get("tgt") is not None:
            self.tgt = test_data["tgt"].to(self.device)
        self.set_network_state(["netSR"], "eval")
        with torch.no_grad():
            if crop_size is None:
                self.fake_tgt = self.netSR(self.src)
            else:
                self.fake_tgt = self.crop_test(self.src, crop_size)
        self.set_network_state(["netSR"], "train")

        if hasattr(self, "netDeg"):
            self.set_network_state(["netDeg"], "eval")
            if hasattr(self, "tgt"):
                with torch.no_grad():
                    self.fake_lr = self.netDeg(self.tgt)[0]
            self.set_network_state(["netDeg"], "train")

    def get_current_visuals(self, need_GT=True):
        # first image of the batch only, moved to CPU for logging/saving
        out_dict = OrderedDict()
        out_dict["lr"] = self.src.detach()[0].float().cpu()
        out_dict["sr"] = self.fake_tgt.detach()[0].float().cpu()
        if hasattr(self, "fake_lr"):
            out_dict["fake_lr"] = self.fake_lr.detach()[0].float().cpu()
        return out_dict

    def crop_test(self, lr, crop_size):
        """Tile-based SR for large inputs: two tiling passes (anchored at the
        top-left and at the bottom-right) whose outputs are blended by masks.
        Uninitialized output pixels are marked with -1."""
        b, c, h, w = lr.shape
        scale = self.opt["scale"]

        # pass 1: tiles anchored at the top-left corner
        h_start = list(range(0, h - crop_size, crop_size))
        w_start = list(range(0, w - crop_size, crop_size))

        sr1 = torch.zeros(b, c, int(h * scale), int(w * scale), device=self.device) - 1
        for hs in h_start:
            for ws in w_start:
                lr_patch = lr[:, :, hs: hs + crop_size, ws: ws + crop_size]
                sr_patch = self.netSR(lr_patch)
                sr1[
                    :, :,
                    int(hs * scale): int((hs + crop_size) * scale),
                    int(ws * scale): int((ws + crop_size) * scale)
                ] = sr_patch

        # pass 2: tiles anchored at the bottom-right corner (covers the remainder)
        h_end = list(range(h, crop_size, -crop_size))
        w_end = list(range(w, crop_size, -crop_size))

        sr2 = torch.zeros(b, c, int(h * scale), int(w * scale), device=self.device) - 1
        for hd in h_end:
            for wd in w_end:
                lr_patch = lr[:, :, hd - crop_size: hd, wd - crop_size: wd]
                sr_patch = self.netSR(lr_patch)
                sr2[
                    :, :,
                    int((hd - crop_size) * scale): int(hd * scale),
                    int((wd - crop_size) * scale): int(wd * scale)
                ] = sr_patch

        # blend: take the pass that covered a pixel; average where both did.
        # NOTE(review): overlap detection uses (sr > 0), so legitimately
        # non-positive SR outputs are treated as "uncovered" — confirm intended.
        mask1 = (
            (sr1 == -1).float() * 0 + (sr2 ==
            -1).float() * 1 + ((sr1 > 0) * (sr2 > 0)).float() * 0.5
        )
        mask2 = (
            (sr1 == -1).float() * 1 + (sr2 == -1).float() * 0
            + ((sr1 > 0) * (sr2 > 0)).float() * 0.5
        )
        sr = mask1 * sr1 + mask2 * sr2
        return sr


class ShuffleBuffer():
    """Random choose previous generated images or ones produced by the latest generators.

    :param buffer_size: the size of image buffer
    :type buffer_size: int
    """

    def __init__(self, buffer_size):
        """Initialize the ImagePool class.

        :param buffer_size: the size of image buffer
        :type buffer_size: int
        """
        self.buffer_size = buffer_size
        self.num_imgs = 0
        self.images = []

    def choose(self, images, prob=0.5):
        """Return an image from the pool.

        :param images: the latest generated images from the generator
        :type images: list
        :param prob: probability (0~1) of return previous images from buffer
        :type prob: float
        :return: Return images from the buffer
        :rtype: list
        """
        # buffer disabled -> pass the batch through untouched
        if self.buffer_size == 0:
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.buffer_size:
                # buffer not yet full: store and return the fresh image
                self.images.append(image)
                return_images.append(image)
                self.num_imgs += 1
            else:
                p = random.uniform(0, 1)
                if p < prob:
                    # swap: return a stored image, keep the fresh one instead
                    idx = random.randint(0, self.buffer_size - 1)
                    stored_image = self.images[idx].clone()
                    self.images[idx] = image
                    return_images.append(stored_image)
                else:
                    return_images.append(image)
        return_images = torch.cat(return_images, 0)
        return return_images


================================================
FILE: codes/config/PDM-SR/options/test/2017Track1.yml
================================================

#### general settings
name: 2017Track1_percep
use_tb_logger: false
model: DegSRModel
scale: 4
gpu_ids: [1]
metrics: [psnr, ssim, lpips, niqe, piqe, brisque]

datasets:
  test1:
    name: 2017Track1
    mode: PairedDataset
    data_type: lmdb
    dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb
    dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb

#### network structures
networks:
  netDeg:
    which_network: DegModel
    setting:
      scale: 4
      nc_img: 3
kernel_opt: mix: false spatial: false nc: 64 nf: 64 nb: 16 body_k: 1 head_k: 1 ksize: 21 zero_init: true noise_opt: ~ pretrain: path: log/2017Track1/2017Track1_deg_best/models/latest_netDeg.pth strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/2017Track1/2017Track1_percep_best/models/latest_netSR.pth strict_load: true ================================================ FILE: codes/config/PDM-SR/options/test/2018Track2.yml ================================================ #### general settings name: 2018Track2_psnr use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [5] metrics: [best_psnr, best_ssim, lpips, niqe, piqe, brisque] datasets: test0: name: 2018Track2 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: mix: false spatial: false nc: 64 nf: 64 nb: 16 head_k: 1 body_k: 1 ksize: 21 zero_init: true noise_opt: mix: true nc: 3 nf: 64 nb: 16 head_k: 3 body_k: 3 dim: 3 zero_init: true pretrain: path: log/2018Track2/2018Track2_deg/models/195000_netDeg.pth strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/2018Track2/2018Track2_psnr/models/latest_netSR.pth strict_load: true ================================================ FILE: codes/config/PDM-SR/options/test/2018Track4.yml ================================================ #### general settings name: 2018Track4_psnr use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [5] metrics: [best_psnr, best_ssim, lpips, niqe, piqe, brisque] datasets: test0: name: 2018Track4 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netDeg: which_network: DegModel 
setting: scale: 4 nc_img: 3 kernel_opt: mix: false spatial: false nc: 3 nf: 64 nb: 16 head_k: 1 body_k: 1 ksize: 21 zero_init: true noise_opt: mix: true nc: 3 nf: 64 nb: 16 head_k: 3 body_k: 3 dim: 3 zero_init: true pretrain: path: log/2018Track4_deg/models/latest_netDeg.pth strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: log/2018Track4_psnr/models/latest_netSR.pth strict_load: true ================================================ FILE: codes/config/PDM-SR/options/test/2020Track1.yml ================================================ #### general settings name: 2020Track1_percep_bsrgan use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [5] metrics: [psnr, ssim, lpips, niqe, piqe, brisque] datasets: test0: name: 2020Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: spatial: false nc: 3 nf: 64 nb: 8 head_k: 1 body_k: 1 ksize: 21 zero_init: true noise_opt: spatial: false nc: 3 nf: 32 nb: 8 head_k: 3 body_k: 3 dim: 1 zero_init: false pretrain: path: log/2020Track1_deg/models/latest_netDeg.pth strict_load: true netSR: which_network: RRDBNet setting: in_nc: 3 out_nc: 3 nf: 64 nb: 23 gc: 32 upscale: 4 pretrain: path: ~ strict_load: true ================================================ FILE: codes/config/PDM-SR/options/test/2020Track2.yml ================================================ ## general settings name: 2020Track2_percep use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [7] metrics: [niqe] datasets: test0: name: 2020Track2 mode: SingleDataset data_type: lmdb dataroot: /home/lzx/SRDatasets/NTIRE2020/track2/test.lmdb #### network structures networks: netSR: which_network: RRDBNet setting: in_nc: 3 out_nc: 3 nf: 64 nb: 23 gc: 32 upscale: 4 pretrain: path: 
../../../checkpoints/PDM_Real_ESRGAN.pth strict_load: true ================================================ FILE: codes/config/PDM-SR/options/train/deg/2017Track1.yml ================================================ #### general settings name: 2017Track1_deg use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [2] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/train_LR/x4_half.lmdb use_shuffle: true workers_per_gpu: 4 # per GPU imgs_per_gpu: 32 tgt_size: 192 src_size: 48 use_flip: true use_rot: true val: name: 2017Track1_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: mix: false spatial: false nc: 64 nf: 64 nb: 16 body_k: 1 head_k: 1 ksize: 21 zero_init: true noise_opt: ~ pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: true optim_sr: false losses: lr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 optimizers: netDeg: type: Adam lr: !!float 2e-4 netD1: type: Adam lr: !!float 2e-4 niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 2e5 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/PDM-SR/options/train/deg/2018Track2.yml 
================================================ #### general settings name: 2018Track2_deg_mse10_mixfale use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [0] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/x4_half.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 192 src_size: 48 use_flip: true use_rot: true val: name: 2018Track2 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: mix: false spatial: false nc: 64 nf: 64 nb: 16 head_k: 1 body_k: 1 ksize: 21 zero_init: true noise_opt: mix: false nc: 3 nf: 64 nb: 16 head_k: 3 body_k: 3 dim: 3 zero_init: true pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: true optim_sr: false losses: lr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 noise_mean: type: MSELoss weight: 10.0 optimizers: netDeg: type: Adam lr: !!float 2e-4 netD1: type: Adam lr: !!float 2e-4 niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 2e5 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/PDM-SR/options/train/deg/2018Track4.yml ================================================ #### general 
settings name: 2018Track4_deg use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [2] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/x4.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2018Track2 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: mix: false spatial: false nc: 3 nf: 64 nb: 16 head_k: 1 body_k: 1 ksize: 21 zero_init: true noise_opt: mix: true nc: 3 nf: 64 nb: 16 head_k: 3 body_k: 3 dim: 3 zero_init: true pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: true optim_sr: false losses: lr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 noise_mean: type: MSELoss weight: 100.0 optimizers: netDeg: type: Adam lr: !!float 2e-4 netD1: type: Adam lr: !!float 2e-4 niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 2e5 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/PDM-SR/options/train/deg/2020Track1.yml ================================================ #### general settings name: 2020Track1_deg_dim1 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: 
[5] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/train_source.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2020Track1 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: spatial: false nc: 3 nf: 64 nb: 8 head_k: 1 body_k: 1 ksize: 21 zero_init: true noise_opt: spatial: false nc: 3 nf: 32 nb: 8 head_k: 3 body_k: 3 dim: 1 zero_init: false pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: true optim_sr: false losses: lr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 noise_mean: type: MSELoss weight: 100.0 optimizers: netDeg: type: Adam lr: !!float 2e-4 netD1: type: Adam lr: !!float 2e-4 niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 2e5 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/PDM-SR/options/train/deg/2020Track2.yml ================================================ #### general settings name: 2020Track2_deg_dim1 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [2] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: 
UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track2/train_source.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2020Track2 mode: SingleImageDataset data_type: lmdb color: RGB dataroot: /home/lzx/SRDatasets/NTIRE2020/track2/test_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: spatial: false nc: 3 nf: 64 nb: 8 head_k: 1 body_k: 1 ksize: 21 zero_init: true noise_opt: spatial: false nc: 3 nf: 32 nb: 8 head_k: 3 body_k: 3 dim: 1 zero_init: false pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 2 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: true optim_sr: false losses: lr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 noise_mean: type: MSELoss weight: 1.0 optimizers: netDeg: type: Adam lr: !!float 2e-4 netD1: type: Adam lr: !!float 2e-4 niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 2e5 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/PDM-SR/options/train/percep/2017Track1.yml ================================================ #### general settings name: 2017Track1_percep_best use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [2] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_src: 
/home/lzx/SRDatasets/NTIRE2017/train_LR/x4_half.lmdb use_shuffle: true workers_per_gpu: 4 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2017Track1_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: mix: false spatial: false nc: 64 nf: 64 nb: 16 body_k: 1 head_k: 1 ksize: 21 zero_init: true noise_opt: ~ pretrain: path: log/2017Track1_deg_best/models/latest_netDeg.pth strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_sr: true optim_deg: false losses: sr_pix_sr: type: L1Loss weight: 1.0 sr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.05 sr_percep: type: PerceptualLoss layer_weights: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true range_norm: false perceptual_weight: 1.0 style_weight: 0 criterion: l1 weight: !!float 0.05 optimizers: netSR: type: Adam lr: !!float 2e-4 netD2: type: Adam lr: !!float 2e-4 niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/PDM-SR/options/train/percep/2018Track2.yml ================================================ #### general settings name: 2018Track2_percep use_tb_logger: false 
model: DegSRModel scale: 4 gpu_ids: [3] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/x4_half.lmdb use_shuffle: true workers_per_gpu: 4 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2018Track2 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: mix: false spatial: false nc: 64 nf: 64 nb: 16 head_k: 1 body_k: 1 ksize: 21 zero_init: true noise_opt: mix: true nc: 3 nf: 64 nb: 16 head_k: 3 body_k: 3 dim: 3 zero_init: true pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: false optim_sr: true losses: sr_pix_sr: type: L1Loss weight: 1.0 sr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.05 sr_percep: type: PerceptualLoss layer_weights: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true range_norm: false perceptual_weight: 1.0 style_weight: 0 criterion: l1 weight: !!float 0.05 optimizers: netSR: type: Adam lr: !!float 2e-4 netD2: type: Adam lr: !!float 2e-4 niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger 
logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/PDM-SR/options/train/percep/2018Track4.yml ================================================ #### general settings name: 2018Track4_percep use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [2] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/x4.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2018Track4 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: mix: false spatial: false nc: 3 nf: 64 nb: 16 head_k: 1 body_k: 1 ksize: 21 zero_init: true noise_opt: mix: true nc: 3 nf: 64 nb: 16 head_k: 3 body_k: 3 dim: 3 zero_init: true pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_sr: true optim_deg: false losses: sr_pix_sr: type: L1Loss weight: 1.0 sr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.05 sr_percep: type: PerceptualLoss layer_weights: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true range_norm: false perceptual_weight: 1.0 style_weight: 0 criterion: l1 weight: !!float 
0.05 optimizers: netSR: type: Adam lr: !!float 2e-4 netD2: type: Adam lr: !!float 2e-4 niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/PDM-SR/options/train/percep/2020Track1.yml ================================================ #### general settings name: 2020Track1_percep_bsrgan use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [2] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/train_source.lmdb use_shuffle: true workers_per_gpu: 4 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2020Track1 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: spatial: false nc: 3 nf: 64 nb: 8 head_k: 1 body_k: 1 ksize: 21 zero_init: true noise_opt: spatial: false nc: 3 nf: 32 nb: 8 head_k: 3 body_k: 3 dim: 1 zero_init: false pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ~ strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: false optim_sr: true losses: sr_pix_sr: type: L1Loss weight: 0.01 sr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 
fake_label_val: 0.0 weight: !!float 0.005 sr_percep: type: PerceptualLoss layer_weights: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true range_norm: false perceptual_weight: 1.0 style_weight: 0 criterion: l1 weight: !!float 0.05 optimizers: netSR: type: Adam lr: !!float 2e-4 netD2: type: Adam lr: !!float 2e-4 niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/PDM-SR/options/train/percep/2020Track2.yml ================================================ #### general settings name: 2020Track2_percep_bsrgan use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [5] metrics: [niqe] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track2/train_source.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2020Track2 mode: SingleDataset data_type: lmdb color: RGB dataroot: /home/lzx/SRDatasets/NTIRE2020/track2/test_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: spatial: false nc: 3 nf: 64 nb: 8 head_k: 1 body_k: 1 ksize: 21 zero_init: true noise_opt: spatial: false nc: 3 nf: 32 nb: 8 head_k: 3 body_k: 3 dim: 1 zero_init: false pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true netD2: which_network: PatchGANDiscriminator setting: in_c: 3 
nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: true optim_sr: true losses: sr_pix_sr: type: L1Loss weight: 1.0 noise_mean: type: MSELoss weight: 1.0 sr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 0.05 sr_percep: type: PerceptualLoss layer_weights: 'conv5_4': 1 # before relu vgg_type: vgg19 use_input_norm: true range_norm: false perceptual_weight: 1.0 style_weight: 0 criterion: l1 weight: !!float 0.05 optimizers: default: type: Adam lr: !!float 2e-4 netSR: ~ netD2: ~ netD1: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/PDM-SR/options/train/psnr/2017Track2.yml ================================================ #### general settings name: 2017Track2_psnr use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [0] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/train_LR/x4_half.lmdb use_shuffle: true workers_per_gpu: 4 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2017Track2_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: mix: false spatial: false nc: 64 nf: 64 nb: 16 body_k: 1 head_k: 1 ksize: 21 zero_init: true noise_opt: ~ pretrain: path: ~ 
strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_sr: true optim_deg: true losses: sr_pix_sr: type: L1Loss weight: 1.0 lr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 optimizers: default: type: Adam lr: !!float 2e-4 netDeg: ~ netSR: ~ netD1: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/PDM-SR/options/train/psnr/2018Track2.yml ================================================ #### general settings name: 2018Track2_psnr_v2 use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [0] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [200, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4_half.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/x4_half.lmdb use_shuffle: true workers_per_gpu: 4 # per GPU imgs_per_gpu: 32 tgt_size: 192 src_size: 48 use_flip: true use_rot: true val: name: 2018Track2 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: mix: false spatial: false nc: 64 nf: 64 nb: 16 head_k: 1 body_k: 1 ksize: 21 zero_init: true noise_opt: mix: true nc: 3 
nf: 64 nb: 16 head_k: 3 body_k: 3 dim: 3 zero_init: true pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: false optim_sr: true losses: sr_pix_sr: type: L1Loss weight: 1.0 optimizers: netSR: type: Adam lr: !!float 2e-4 niter: 200000 warmup_iter: -1 # no warm up schedulers: netSR: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/PDM-SR/options/train/psnr/2018Track4.yml ================================================ #### general settings name: 2018Track4_psnr use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [1] metrics: [best_psnr, best_ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/x4.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2018Track4 mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: mix: false spatial: false nc: 3 nf: 64 nb: 16 head_k: 1 body_k: 1 ksize: 21 zero_init: true noise_opt: mix: true nc: 3 nf: 64 nb: 16 head_k: 3 body_k: 3 dim: 3 zero_init: true pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: 
../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_deg: false optim_sr: true losses: sr_pix_sr: type: L1Loss weight: 1.0 optimizers: netSR: type: Adam lr: !!float 2e-4 niter: 200000 warmup_iter: -1 # no warm up schedulers: netSR: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/PDM-SR/options/train/psnr/2020Track1.yml ================================================ #### general settings name: 2020Track1_psnr use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [1] metrics: [psnr, ssim, lpips] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/train_source.lmdb use_shuffle: true workers_per_gpu: 4 # per GPU imgs_per_gpu: 32 tgt_size: 128 src_size: 32 use_flip: true use_rot: true val: name: 2020Track2_mini mode: PairedDataset data_type: lmdb color: RGB dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid_mini.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: mix: false spatial: false nc: 3 nf: 64 nb: 8 body_k: 1 head_k: 1 ksize: 11 zero_init: true noise_opt: mix: true spatial: true nc: 3 nf: 64 nb: 8 body_k: 3 head_k: 3 dim: 3 zero_init: true pretrain: path: ~ strict_load: true netD1: which_network: PatchGANDiscriminator setting: in_c: 3 nf: 64 nb: 3 stride: 1 pretrain: path: ~ strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt 
strict_load: true #### training settings: learning rate scheme, loss train: resume_state: ~ D_ratio: 1 max_grad_norm: 50 buffer_size: 0 optim_sr: true optim_deg: true losses: sr_pix_sr: type: L1Loss weight: 1.0 lr_adv: type: GANLoss gan_type: lsgan real_label_val: 1.0 fake_label_val: 0.0 weight: !!float 1.0 noise_mean: type: MSELoss weight: !!float 100 optimizers: default: type: Adam lr: !!float 2e-4 netDeg: ~ netSR: ~ netD1: ~ niter: 200000 warmup_iter: -1 # no warm up schedulers: default: type: MultiStepRestartLR milestones: [50000, 100000, 150000] gamma: 0.5 manual_seed: 0 val_freq: !!float 5e3 #### logger logger: print_freq: 100 save_checkpoint_freq: !!float 5e3 ================================================ FILE: codes/config/PDM-SR/options/train/psnr/2020Track2.yml ================================================ #### general settings name: 2020Track2_psnr use_tb_logger: false model: DegSRModel scale: 4 gpu_ids: [2] metrics: [niqe] #### datasets datasets: train: name: DIV2K mode: UnPairedDataset data_type: lmdb color: RGB ratios: [50, 200] dataroot_tgt: /home/lzx/SRDatasets/DIV2K_train/HR/x4.lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track2/train_source.lmdb use_shuffle: true workers_per_gpu: 6 # per GPU imgs_per_gpu: 32 tgt_size: 192 src_size: 48 use_flip: true use_rot: true val: name: 2020Track2 mode: SingleDataset data_type: lmdb color: RGB dataroot: /home/lzx/SRDatasets/NTIRE2020/track2/test_mini.lmdb #### network structures networks: netDeg: which_network: DegModel setting: scale: 4 nc_img: 3 kernel_opt: mix: false spatial: false nc: 64 nf: 64 nb: 8 head_k: 1 body_k: 1 ksize: 21 zero_init: true noise_opt: mix: true nc: 3 nf: 64 nb: 8 head_k: 3 body_k: 3 dim: 3 zero_init: true pretrain: path: log/2020Track2_deg_mse10/models/195000_netDeg.pth strict_load: true netSR: which_network: EDSR setting: nf: 64 nb: 16 res_scale: 1 upscale: 4 pretrain: path: ../../../checkpoints/EDSR/edsr_baseline_x4-new.pt strict_load: true #### training settings: learning 
def parse_args():
    """Parse command-line arguments for testing.

    Returns:
        argparse.Namespace with: opt (required config path), root_path,
        gpu, world_size, dist_url and rank (distributed settings).
    """
    parser = argparse.ArgumentParser(description="Train keypoints network")
    # general
    parser.add_argument(
        "--opt", help="experiment configure file name", required=True, type=str
    )
    parser.add_argument(
        "--root_path",
        help="experiment configure file name",
        default="../../../",
        type=str,
    )
    # distributed training
    parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str)
    parser.add_argument(
        "--world-size",
        default=1,
        type=int,
        help="number of nodes for distributed training",
    )
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:23456",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--rank", default=0, type=int, help="node rank for distributed training"
    )
    args = parser.parse_args()
    return args


def main():
    """Entry point: parse options, set up dirs/symlink, spawn test workers."""
    args = parse_args()
    opt = option.parse(args.opt, args.root_path, is_train=False)

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])
    # world_size is scaled by local GPU count: one worker process per GPU
    ngpus_per_node = torch.cuda.device_count()
    args.world_size = ngpus_per_node * args.world_size
    opt["dist"] = args.world_size > 1

    util.mkdirs(
        (path for key, path in opt["path"].items() if not key == "experiments_root")
    )
    # refresh the ./result symlink to point at the current results root
    os.system("rm ./result")
    os.symlink(os.path.join(opt["path"]["results_root"], ".."), "./result")

    if opt["dist"]:
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args))
    else:
        main_worker(0, 1, opt, args)


def main_worker(gpu, ngpus_per_node, opt, args):
    """Per-process worker: init (optional) process group, build loaders/model,
    then run validate() over every configured test set.

    Args:
        gpu: local GPU index for this worker.
        ngpus_per_node: GPUs per node (used to derive the global rank).
        opt: parsed option dict (NoneDict).
        args: command-line namespace from parse_args().
    """
    if opt["dist"]:
        if args.dist_url == "env://" and args.rank == -1:
            rank = int(os.environ["RANK"])
        # NOTE(review): the env:// value above is unconditionally overwritten
        # here (same pattern as train.py) — confirm intended for env:// mode.
        rank = args.rank * ngpus_per_node + gpu
        print(
            f"Init process group: dist_url: {args.dist_url}, world_size: {args.world_size}, rank: {rank}"
        )
        dist.init_process_group(
            backend="nccl",
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=rank,
        )
        torch.cuda.set_device(gpu)
    else:
        rank = 0

    torch.backends.cudnn.benchmark = True

    util.setup_logger(
        "base",
        opt["path"]["log"],
        "test_" + opt["name"] + "_rank{}".format(rank),
        level=logging.INFO,
        screen=True,
        tofile=True,
    )
    measure = IQA(metrics=opt["metrics"], cuda=True)
    logger = logging.getLogger("base")
    logger.info(option.dict2str(opt))

    # Create test dataset and dataloader
    test_datasets = []
    test_loaders = []
    for phase, dataset_opt in sorted(opt["datasets"].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt, opt["dist"])
        if rank == 0:
            logger.info(
                "Number of test images in [{:s}]: {:d}".format(
                    dataset_opt["name"], len(test_set)
                )
            )
        test_datasets.append(test_set)
        test_loaders.append(test_loader)

    # load pretrained model by default
    model = create_model(opt)

    for test_dataset, test_loader in zip(test_datasets, test_loaders):
        test_set_name = test_dataset.opt["name"]
        dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name)
        if rank == 0:
            logger.info("\nTesting [{:s}]...".format(test_set_name))
            util.mkdir(dataset_dir)
        validate(
            model,
            test_dataset,
            test_loader,
            opt,
            measure,
            dataset_dir,
            test_set_name,
            logger,
        )


def validate(
    model, dataset, dist_loader, opt, measure, dataset_dir, test_set_name, logger
):
    """Run the model over one test set, save SR images, and log per-image and
    average metrics (BGR and, for 3-channel images, Y-channel).

    In distributed mode each rank scores a strided subset of the dataset;
    scores are summed onto rank 0 via dist.reduce before averaging.
    """
    test_results = {}
    test_results_y = {}
    # one score slot per image; zeros for images handled by other ranks are
    # filled in by the reduce below
    for metric in opt["metrics"]:
        test_results[metric] = torch.zeros((len(dataset))).cuda()
        test_results_y[metric] = torch.zeros((len(dataset))).cuda()

    if opt["dist"]:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        world_size = 1
        rank = 0

    # global dataset indices assigned to this rank (round-robin)
    indices = list(range(rank, len(dataset), world_size))

    for (
        idx,
        test_data,
    ) in enumerate(dist_loader):
        idx = indices[idx]
        img_path = test_data["src_path"][0]
        img_name = img_path.split("/")[-1].split(".")[0]

        model.test(test_data)
        visuals = model.get_current_visuals()
        sr_img = util.tensor2img(visuals["sr"])  # uint8

        suffix = opt["suffix"]
        if suffix:
            save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png")
        else:
            save_img_path = os.path.join(dataset_dir, img_name + ".png")
        util.save_img(sr_img, save_img_path)

        message = "img:{:15s}; ".format(img_name)

        # crop borders before scoring (defaults to the SR scale)
        crop_border = opt["crop_border"] if opt["crop_border"] else opt["scale"]
        if crop_border == 0:
            cropped_sr_img = sr_img
        else:
            cropped_sr_img = sr_img[
                crop_border:-crop_border, crop_border:-crop_border, :
            ]

        if "tgt" in test_data.keys():
            gt_img = util.tensor2img(test_data["tgt"][0].double().cpu())
            if crop_border == 0:
                cropped_gt_img = gt_img
            else:
                cropped_gt_img = gt_img[
                    crop_border:-crop_border, crop_border:-crop_border, :
                ]
        else:
            # no-reference setting: metrics must cope with ref=None
            gt_img = None
            cropped_gt_img = None

        message += "Scores - "
        scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"])
        for k, v in scores.items():
            test_results[k][idx] = v
            message += "{}: {:.6f}; ".format(k, v)

        if sr_img.shape[2] == 3:  # RGB image: also score on the Y channel
            sr_img_y = bgr2ycbcr(sr_img, only_y=True)
            if crop_border == 0:
                cropped_sr_img_y = sr_img_y * 255
            else:
                cropped_sr_img_y = (
                    sr_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255
                )
            if gt_img is not None:
                gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                if crop_border == 0:
                    cropped_gt_img_y = gt_img_y * 255
                else:
                    cropped_gt_img_y = (
                        gt_img_y[crop_border:-crop_border, crop_border:-crop_border]
                        * 255
                    )
            else:
                gt_img_y = None
                cropped_gt_img_y = None

            message += "Y Scores - "
            scores = measure(
                res=cropped_sr_img_y, ref=cropped_gt_img_y, metrics=opt["metrics"]
            )
            for k, v in scores.items():
                test_results_y[k][idx] = v
                message += "{}: {:.6f}; ".format(k, v)

        logger.info(message)

    if opt["dist"]:
        # sum each rank's partial score vectors onto rank 0
        for k, v in test_results.items():
            dist.reduce(v, dst=0)
        dist.barrier()
        for k, v in test_results_y.items():
            dist.reduce(v, dst=0)
        dist.barrier()

    # log averages (rank 0 only)
    avg_results = {}
    message = "Average Results for {}\n".format(test_set_name)
    if rank == 0:
        for k, v in test_results.items():
            avg_results[k] = sum(v) / len(v)
            message += "{}: {:.6f}; ".format(k, avg_results[k])
        logger.info(message)

    avg_results_y = {}
    message = "Average Results on Y channel for {}\n".format(test_set_name)
    if rank == 0:
        for k, v in test_results_y.items():
            # BUGFIX: previously wrote into avg_results, leaving avg_results_y
            # empty and logging the BGR averages under the Y-channel header.
            avg_results_y[k] = sum(v) / len(v)
            message += "{}: {:.6f}; ".format(k, avg_results_y[k])
        logger.info(message)


if __name__ == "__main__":
    main()
"--root_path", help="experiment configure file name", default="../../../", type=str, ) # distributed training parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str) parser.add_argument( "--world-size", default=1, type=int, help="number of nodes for distributed training", ) parser.add_argument( "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training", ) parser.add_argument( "--rank", default=0, type=int, help="node rank for distributed training" ) args = parser.parse_args() return args def setup_dataloaer(opt, logger): if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: rank = 0 world_size = 1 for phase, dataset_opt in opt["datasets"].items(): if phase == "train": train_set = create_dataset(dataset_opt) train_loader = create_dataloader(train_set, dataset_opt, opt["dist"]) total_iters = opt["train"]["niter"] total_epochs = total_iters // (len(train_loader) - 1) + 1 if rank == 0: logger.info( "Number of train images: {:,d}, iters: {:,d}".format( len(train_set), len(train_loader) ) ) logger.info( "Total epochs needed: {:d} for iters {:,d}".format( total_epochs, opt["train"]["niter"] ) ) elif phase == "val": val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt["dist"]) if rank == 0: logger.info( "Number of val images in [{:s}]: {:d}".format( dataset_opt["name"], len(val_set) ) ) else: raise NotImplementedError("Phase [{:s}] is not recognized.".format(phase)) assert train_loader is not None assert val_loader is not None return train_set, train_loader, val_set, val_loader, total_iters, total_epochs def main(): args = parse_args() opt = option.parse(args.opt, args.root_path, is_train=True) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) if args.dist_url == "env://" and args.world_size == -1: args.world_size = int(os.environ["WORLD_SIZE"]) ngpus_per_node = torch.cuda.device_count() 
args.world_size = ngpus_per_node * args.world_size opt["dist"] = args.world_size > 1 if opt["train"].get("resume_state", None) is None: util.mkdir_and_rename( opt["path"]["experiments_root"] ) # rename experiment folder if exists util.mkdirs( (path for key, path in opt["path"].items() if not key == "experiments_root") ) os.system("rm ./log") os.symlink(os.path.join(opt["path"]["experiments_root"], ".."), "./log") if opt["dist"]: mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args)) else: main_worker(0, 1, opt, args) def main_worker(gpu, ngpus_per_node, opt, args): if opt["dist"]: if args.dist_url == "env://" and args.rank == -1: rank = int(os.environ["RANK"]) rank = args.rank * ngpus_per_node + gpu print( f"Init process group: dist_url: \ {args.dist_url}, world_size: {args.world_size}, rank: {rank}" ) dist.init_process_group( backend="nccl", init_method=args.dist_url, world_size=args.world_size, rank=rank, ) torch.cuda.set_device(gpu) else: rank = 0 seed = opt["train"]["manual_seed"] if seed is None: util.set_random_seed(rank) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True # setup tensorboard and val logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger = SummaryWriter(log_dir="log/{}/tb_logger/".format(opt["name"])) util.setup_logger( "val", opt["path"]["log"], "val_" + opt["name"], level=logging.INFO, screen=True, tofile=True, ) measure = IQA(metrics=opt["metrics"], cuda=True) # config loggers. 
Before it, the log will not work util.setup_logger( "base", opt["path"]["log"], "train_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO if rank == 0 else logging.ERROR, screen=True, tofile=True, ) logger = logging.getLogger("base") if rank == 0: logger.info(option.dict2str(opt)) # create dataset ( train_set, train_loader, val_set, val_loader, total_iters, total_epochs, ) = setup_dataloaer(opt, logger) # create model model = create_model(opt) # loading resume state if exists if opt["train"].get("resume_state", None): # distributed resuming: all load into default GPU device_id = gpu resume_state = torch.load( opt["train"]["resume_state"], map_location=lambda storage, loc: storage.cuda(device_id), ) logger.info( "Resuming training from epoch: {}, iter: {}.".format( resume_state["epoch"], resume_state["iter"] ) ) start_epoch = resume_state["epoch"] current_step = resume_state["iter"] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 logger.info( "Start training from epoch: {:d}, iter: {:d}".format(start_epoch, current_step) ) data_time, iter_time = time.time(), time.time() avg_data_time = avg_iter_time = 0 count = 0 for epoch in range(start_epoch, total_epochs + 1): for _, train_data in enumerate(train_loader): current_step += 1 count += 1 if current_step > total_iters: break data_time = time.time() - data_time avg_data_time = (avg_data_time * (count - 1) + data_time) / count model.feed_data(train_data) model.optimize_parameters(current_step) model.update_learning_rate( current_step, warmup_iter=opt["train"]["warmup_iter"] ) iter_time = time.time() - iter_time avg_iter_time = (avg_iter_time * (count - 1) + iter_time) / count # log if current_step % opt["logger"]["print_freq"] == 0: logs = model.get_current_log() message = ( f" " ) message += f'[time (data): {avg_iter_time:.3f} ({avg_data_time:.3f})] ' for k, v in logs.items(): message += "{:s}: {:.4e}; ".format(k, v) # tensorboard logger if 
opt["use_tb_logger"] and "debug" not in opt["name"]: if rank == 0: tb_logger.add_scalar(k, v, current_step) logger.info(message) # validation if current_step % opt["train"]["val_freq"] == 0: avg_results = validate( model, val_set, val_loader, opt, measure, epoch, current_step ) # tensorboard logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: for k, v in avg_results.items(): tb_logger.add_scalar(k, v, current_step) # save models and training states if current_step % opt["logger"]["save_checkpoint_freq"] == 0: if rank == 0: logger.info("Saving models and training states.") model.save(current_step) model.save_training_state(epoch, current_step) data_time = time.time() iter_time = time.time() if rank == 0: logger.info("Saving the final model.") model.save("latest") logger.info("End of training.") if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger.close() def validate(model, dataset, dist_loader, opt, measure, epoch, current_step): test_results = {} for metric in opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 if rank == 0: pbar = tqdm(total=len(dataset), leave=False, dynamic_ncols=True) indices = list(range(rank, len(dataset), world_size)) for ( idx, val_data, ) in enumerate(dist_loader): idx = indices[idx] LR_img = val_data["src"] lr_img = util.tensor2img(LR_img) # save LR image for reference model.test(val_data) visuals = model.get_current_visuals() # Save images for reference img_name = val_data["src_path"][0].split("/")[-1].split(".")[0] img_dir = os.path.join(opt["path"]["val_images"], img_name) util.mkdir(img_dir) save_lr_path = os.path.join(img_dir, "{:s}_LR.png".format(img_name)) util.save_img(lr_img, save_lr_path) sr_img = util.tensor2img(visuals["sr"]) # uint8 save_img_path = os.path.join( img_dir, "{:s}_{:d}.png".format(img_name, current_step) ) util.save_img(sr_img, save_img_path) if 
"fake_lr" in visuals.keys(): fake_lr_img = util.tensor2img(visuals["fake_lr"]) save_img_path = os.path.join( img_dir, f"fake_lr_{current_step:d}.png" ) util.save_img(fake_lr_img, save_img_path) # calculate scores crop_size = opt["scale"] cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] if "tgt" in val_data.keys(): gt_img = util.tensor2img(val_data["tgt"]) cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] else: cropped_gt_img = gt_img = None scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v if rank == 0: for _ in range(world_size): pbar.update(1) if rank == 0: pbar.close() # log avg_results = {} message = " 0`, the perceptual loss will be calculated and the loss will multiplied by the weight. Default: 1.0. style_weight (float): If `style_weight > 0`, the style loss will be calculated and the loss will multiplied by the weight. Default: 0. criterion (str): Criterion used for perceptual loss. Default: 'l1'. """ def __init__( self, layer_weights, vgg_type="vgg19", use_input_norm=True, range_norm=False, perceptual_weight=1.0, style_weight=0.0, criterion="l1", ): super(PerceptualLoss, self).__init__() self.perceptual_weight = perceptual_weight self.style_weight = style_weight self.layer_weights = layer_weights self.vgg = VGGFeatureExtractor( layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, use_input_norm=use_input_norm, range_norm=range_norm, ) self.criterion_type = criterion if self.criterion_type == "l1": self.criterion = torch.nn.L1Loss() elif self.criterion_type == "l2": self.criterion = torch.nn.L2loss() elif self.criterion_type == "fro": self.criterion = None else: raise NotImplementedError(f"{criterion} criterion has not been supported.") def forward(self, x, gt): """Forward function. Args: x (Tensor): Input tensor with shape (n, c, h, w). gt (Tensor): Ground-truth tensor with shape (n, c, h, w). 
def forward(self, x, gt):
    """Compute the perceptual and style losses between `x` and `gt`.

    Args:
        x (Tensor): input batch (n, c, h, w).
        gt (Tensor): ground-truth batch (n, c, h, w); detached before use.

    Returns:
        tuple(Tensor | None, Tensor | None): (perceptual loss, style loss);
        an entry is None when its weight is not positive.
    """
    feats_x = self.vgg(x)
    feats_gt = self.vgg(gt.detach())

    use_fro = self.criterion_type == "fro"

    # feature-space (perceptual) term
    percep_loss = None
    if self.perceptual_weight > 0:
        percep_loss = 0
        for name in feats_x.keys():
            weight = self.layer_weights[name]
            if use_fro:
                term = torch.norm(feats_x[name] - feats_gt[name], p="fro")
            else:
                term = self.criterion(feats_x[name], feats_gt[name])
            percep_loss = percep_loss + term * weight
        percep_loss = percep_loss * self.perceptual_weight

    # Gram-matrix (style) term
    style_loss = None
    if self.style_weight > 0:
        style_loss = 0
        for name in feats_x.keys():
            gram_x = self._gram_mat(feats_x[name])
            gram_gt = self._gram_mat(feats_gt[name])
            if use_fro:
                term = torch.norm(gram_x - gram_gt, p="fro")
            else:
                term = self.criterion(gram_x, gram_gt)
            style_loss = style_loss + term * self.layer_weights[name]
        style_loss = style_loss * self.style_weight

    return percep_loss, style_loss
""" n, c, h, w = x.size() features = x.view(n, c, w * h) features_t = features.transpose(1, 2) gram = features.bmm(features_t) / (c * h * w) return gram @LOSS_REGISTRY.register() class CharbonnierLoss(nn.Module): """Charbonnier Loss (L1)""" def __init__(self, eps=1e-6): super(CharbonnierLoss, self).__init__() self.eps = eps def forward(self, x, y): diff = x - y loss = torch.mean(torch.sqrt(diff * diff + self.eps)) return loss class GradientPenaltyLoss(nn.Module): def __init__(self, device=torch.device("cpu")): super(GradientPenaltyLoss, self).__init__() self.register_buffer("grad_outputs", torch.Tensor()) self.grad_outputs = self.grad_outputs.to(device) def get_grad_outputs(self, input): if self.grad_outputs.size() != input.size(): self.grad_outputs.resize_(input.size()).fill_(1.0) return self.grad_outputs def forward(self, interp, interp_crit): grad_outputs = self.get_grad_outputs(interp_crit) grad_interp = torch.autograd.grad( outputs=interp_crit, inputs=interp, grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True, )[0] grad_interp = grad_interp.view(grad_interp.size(0), -1) grad_interp_norm = grad_interp.norm(2, dim=1) loss = ((grad_interp_norm - 1) ** 2).mean() return loss ================================================ FILE: codes/config/RealESRGAN/archs/lr_scheduler.py ================================================ import math from collections import Counter, defaultdict import torch from torch.optim.lr_scheduler import _LRScheduler from utils.registry import LR_SCHEDULER_REGISTRY @LR_SCHEDULER_REGISTRY.register() class LinearDecayLR(_LRScheduler): def __init__( self, optimizer, decay_prop, total_steps, last_epoch=-1, ): self.decay_prop = decay_prop self.total_steps = total_steps super().__init__(optimizer, last_epoch) def get_lr(self): return [ group["initial_lr"] * (1 - (self.last_epoch + 1) * self.decay_prop / self.total_steps) for group in self.optimizer.param_groups ] @LR_SCHEDULER_REGISTRY.register() class 
class MultiStepRestartLR(_LRScheduler):
    """Multi-step decay schedule with optional warm restarts.

    At every epoch listed in `restarts`, the LR is reset to
    `initial_lr * weight` (and optimizer state optionally cleared); at every
    epoch listed in `milestones`, the LR is multiplied by `gamma` once per
    occurrence of that epoch in the list.

    Args:
        optimizer: wrapped optimizer.
        milestones (list[int]): decay epochs (duplicates compound).
        restarts (list[int] | None): restart epochs. Default: [0].
        weights (list[float] | None): LR weight per restart. Default: [1].
        gamma (float): decay factor per milestone hit. Default: 0.1.
        clear_state (bool): reset optimizer state at restarts. Default: False.
        last_epoch (int): as in _LRScheduler. Default: -1.
    """

    def __init__(
        self,
        optimizer,
        milestones,
        restarts=None,
        weights=None,
        gamma=0.1,
        clear_state=False,
        last_epoch=-1,
    ):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        self.clear_state = clear_state
        self.restarts = restarts if restarts else [0]
        self.restart_weights = weights if weights else [1]
        assert len(self.restarts) == len(
            self.restart_weights
        ), "restarts and their weights do not match."
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        epoch = self.last_epoch
        if epoch in self.restarts:
            if self.clear_state:
                self.optimizer.state = defaultdict(dict)
            weight = self.restart_weights[self.restarts.index(epoch)]
            return [g["initial_lr"] * weight for g in self.optimizer.param_groups]
        hits = self.milestones[epoch]  # Counter yields 0 for non-milestones
        if hits == 0:
            return [g["lr"] for g in self.optimizer.param_groups]
        decay = self.gamma ** hits
        return [g["lr"] * decay for g in self.optimizer.param_groups]
def initialize_weights(net_l, scale=1):
    """Kaiming-initialize conv/linear layers of the given network(s).

    Args:
        net_l (nn.Module | list[nn.Module]): network or list of networks.
        scale (float): factor applied to conv/linear weights after init;
            values < 1 damp the initialization (used for residual branches).

    Conv2d and Linear get kaiming_normal_ (fan_in) scaled weights and zeroed
    biases; BatchNorm2d gets weight=1, bias=0.
    """
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for module in net.modules():
            # Conv2d and Linear receive identical treatment.
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(module.weight, a=0, mode="fan_in")
                module.weight.data *= scale
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias.data, 0.0)
def flow_warp(x, flow, interp_mode="bilinear", padding_mode="zeros", align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): size (N, C, H, W).
        flow (Tensor): size (N, H, W, 2), displacement in pixel units.
        interp_mode (str): 'nearest' or 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
        align_corners (bool): passed to grid_sample; must stay True to match
            the [-1, 1] normalization below. Default: True.

    Returns:
        Tensor: warped image or feature map, same shape as x.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, H, W = x.size()
    # mesh grid of integer pixel coordinates (x along W, y along H)
    grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
    grid = torch.stack((grid_x, grid_y), 2).float()  # (H, W, 2)
    grid.requires_grad = False
    grid = grid.type_as(x)
    vgrid = grid + flow
    # scale grid to [-1, 1]; dividing by (W-1)/(H-1) maps pixel CENTERS to the
    # interval endpoints, which is the align_corners=True convention.
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    # BUGFIX: since torch 1.3 grid_sample defaults to align_corners=False,
    # which breaks the identity property (zero flow no longer reproduces the
    # input) given the normalization above; pass align_corners explicitly.
    output = F.grid_sample(
        x,
        vgrid_scaled,
        mode=interp_mode,
        padding_mode=padding_mode,
        align_corners=align_corners,
    )
    return output
class ResBlock(nn.Module):
    """Residual block: conv-(bn)-act-conv-(bn) with a scaled skip connection.

    Args:
        conv (callable): conv factory, called as conv(in, out, ksize, bias=...).
        n_feat (int): channel count (kept constant through the block).
        kernel_size (int): conv kernel size.
        bias (bool): conv bias. Default: True.
        bn (bool): insert BatchNorm2d after each conv. Default: False.
        act (nn.Module): activation after the first conv. Default: ReLU(True).
        res_scale (float): multiplier on the residual branch. Default: 1.
    """

    def __init__(
        self,
        conv,
        n_feat,
        kernel_size,
        bias=True,
        bn=False,
        act=nn.ReLU(True),
        res_scale=1,
    ):
        super(ResBlock, self).__init__()
        # first conv (+bn) followed by activation
        stage_one = [conv(n_feat, n_feat, kernel_size, bias=bias)]
        if bn:
            stage_one.append(nn.BatchNorm2d(n_feat))
        stage_one.append(act)
        # second conv (+bn), no activation
        stage_two = [conv(n_feat, n_feat, kernel_size, bias=bias)]
        if bn:
            stage_two.append(nn.BatchNorm2d(n_feat))
        self.body = nn.Sequential(*(stage_one + stage_two))
        self.res_scale = res_scale

    def forward(self, x):
        return x + self.body(x).mul(self.res_scale)
class CALayer(nn.Module):
    """Channel attention: squeeze (global average pool) -> excite
    (1x1 bottleneck MLP + sigmoid) -> per-channel rescale of the input.

    Args:
        channel (int): number of input/output channels.
        reduction (int): bottleneck reduction ratio. Default: 16.
    """

    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # squeeze: spatial map -> one value per channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        hidden = channel // reduction
        # excite: channel bottleneck producing weights in (0, 1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        attention = self.conv_du(self.avg_pool(x))
        return x * attention
## Residual Channel Attention Network (RCAN)
@ARCH_REGISTRY.register()
class RCAN(nn.Module):
    """Residual Channel Attention Network (RCAN) super-resolution backbone.

    Pipeline: subtract DIV2K RGB mean -> head conv -> `ng` residual groups
    (each with `nb` RCAB blocks) plus a long skip -> pixel-shuffle upsampler
    tail -> add mean back.

    Args:
        ng (int): number of residual groups.
        nb (int): number of RCAB blocks per group.
        nf (int): feature channel count.
        reduction (int): channel-attention reduction ratio. Default: 16.
        upscale (int): upsampling factor for the tail. Default: 4.
        conv (callable): conv factory. Default: default_conv.
    """

    def __init__(self, ng, nb, nf, reduction=16, upscale=4, conv=default_conv):
        super(RCAN, self).__init__()

        # NOTE(review): n_resgroups/n_resblocks are assigned but the list
        # comprehensions below use ng/nb directly; kept as-is.
        n_resgroups = ng
        n_resblocks = nb
        n_feats = nf
        kernel_size = 3
        reduction = reduction
        scale = upscale
        act = nn.ReLU(True)

        # RGB mean for DIV2K (inputs assumed in [0, 1] given rgb_range=1.0)
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = MeanShift(1.0, rgb_mean, rgb_std, -1)

        # define head module
        modules_head = [conv(3, n_feats, kernel_size)]

        # define body module: ng residual groups followed by one conv
        modules_body = [
            ResidualGroup(
                conv,
                n_feats,
                kernel_size,
                reduction,
                act=act,
                res_scale=1.0,
                n_resblocks=nb,
            )
            for _ in range(ng)
        ]
        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module: upsampler then projection back to RGB
        modules_tail = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, 3, kernel_size),
        ]

        self.add_mean = MeanShift(1.0, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        """Super-resolve `x` (N, 3, H, W) -> (N, 3, H*scale, W*scale)."""
        x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        # long skip connection over the whole body
        res += x

        x = self.tail(res)
        x = self.add_mean(x)

        return x

    def load_state_dict(self, state_dict, strict=False):
        """Load weights, tolerating a mismatched/missing `tail` (upsampler).

        Allows reusing a checkpoint trained at a different scale: size
        mismatches and unknown keys are fatal unless the parameter name
        contains "tail", in which case the upsampler is simply re-initialized.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    # unwrap Parameter to its underlying tensor before copying
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    # size mismatch: only acceptable for the upsampler tail
                    if name.find("tail") >= 0:
                        print("Replace pre-trained upsampler to new one...")
                    else:
                        raise RuntimeError(
                            "While copying the parameter named {}, "
                            "whose dimensions in the model are {} and "
                            "whose dimensions in the checkpoint are {}.".format(
                                name, own_state[name].size(), param.size()
                            )
                        )
            elif strict:
                # unknown key: fatal in strict mode unless it belongs to tail
                if name.find("tail") == -1:
                    raise KeyError('unexpected key "{}" in state_dict'.format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))
@ARCH_REGISTRY.register()
class MSRResNet(nn.Module):
    """Modified SRResNet: conv head, BN-free residual trunk, pixel-shuffle
    upsampling (one stage for x2/x3, two x2 stages for x4) and a bilinear
    global skip from the input.

    Args:
        in_nc (int): input channels. Default: 3.
        out_nc (int): output channels. Default: 3.
        nf (int): feature channels. Default: 64.
        nb (int): number of residual blocks. Default: 16.
        upscale (int): 2, 3 or 4. Default: 4.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
        super(MSRResNet, self).__init__()
        self.upscale = upscale

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.recon_trunk = make_layer(
            functools.partial(ResidualBlock_noBN, nf=nf), nb
        )

        # upsampling: single pixel-shuffle stage for x2/x3, two for x4
        if self.upscale in (2, 3):
            self.upconv1 = nn.Conv2d(
                nf, nf * self.upscale * self.upscale, 3, 1, 1, bias=True
            )
            self.pixel_shuffle = nn.PixelShuffle(self.upscale)
        elif self.upscale == 4:
            self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
            self.pixel_shuffle = nn.PixelShuffle(2)

        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # damped initialization (0.1) for stability
        initialize_weights(
            [self.conv_first, self.upconv1, self.HRconv, self.conv_last], 0.1
        )
        if self.upscale == 4:
            initialize_weights(self.upconv2, 0.1)

    def forward(self, x):
        feat = self.lrelu(self.conv_first(x))
        out = self.recon_trunk(feat)

        if self.upscale == 4:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        else:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))

        out = self.conv_last(self.lrelu(self.HRconv(out)))
        # global residual: bilinear upsample of the input
        base = F.interpolate(
            x, scale_factor=self.upscale, mode="bilinear", align_corners=False
        )
        return out + base
@ARCH_REGISTRY.register()
class Translator(nn.Module):
    """Image-to-image translator with a resolution-changing head/tail.

    Head: plain conv when scale >= 1, strided conv to downscale when
    scale < 1 (stride = round(1/scale)). Body: `nb` ResBlocks with a long
    skip. Tail: pixel-shuffle Upsampler when scale > 1, identity otherwise,
    then a projection conv to `out_nc` channels.

    Args:
        in_nc (int): input channels.
        out_nc (int): output channels.
        nf (int): feature channels.
        nb (int): number of residual blocks.
        scale (int | float): spatial scale factor. Default: 4.
        conv (callable): conv factory. Default: default_conv.
    """

    def __init__(self, in_nc, out_nc, nf, nb, scale=4, conv=default_conv):
        super().__init__()
        self.scale = scale

        # head: keep resolution (scale >= 1) or downscale with a strided conv
        if scale >= 1:
            head = [conv(in_nc, nf, 3)]
        else:
            stride = int(1 / scale)
            head = [
                nn.Conv2d(
                    in_nc, nf, kernel_size=2 * stride + 1, stride=stride, padding=stride
                )
            ]

        body = [
            ResBlock(conv, nf, 3, act=nn.ReLU(True), res_scale=1) for _ in range(nb)
        ]
        body.append(conv(nf, nf, 3))

        tail = [
            Upsampler(conv, scale, nf, act=False) if scale > 1 else nn.Identity(),
            conv(nf, out_nc, 3),
        ]

        self.head = nn.Sequential(*head)
        self.body = nn.Sequential(*body)
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        shallow = self.head(x)
        deep = self.body(shallow)
        # long skip over the residual body
        return self.tail(shallow + deep)
"conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "pool5", ], "vgg16": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "pool5", ], "vgg19": [ "conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1", "conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2", "conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "conv3_4", "relu3_4", "pool3", "conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "conv4_4", "relu4_4", "pool4", "conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "conv5_4", "relu5_4", "pool5", ], } def insert_bn(names): """Insert bn layer after each conv. Args: names (list): The list of layer names. Returns: list: The list of layer names with bn layers. """ names_bn = [] for name in names: names_bn.append(name) if "conv" in name: position = name.replace("conv", "") names_bn.append("bn" + position) return names_bn @ARCH_REGISTRY.register() class VGGFeatureExtractor(nn.Module): """VGG network for feature extraction. In this implementation, we allow users to choose whether use normalization in the input feature and the type of vgg network. Note that the pretrained path must fit the vgg type. Args: layer_name_list (list[str]): Forward function returns the corresponding features according to the layer_name_list. Example: {'relu1_1', 'relu2_1', 'relu3_1'}. vgg_type (str): Set the type of vgg network. Default: 'vgg19'. use_input_norm (bool): If True, normalize the input image. Importantly, the input feature must in the range [0, 1]. Default: True. 
def __init__(
    self,
    layer_name_list,
    vgg_type="vgg19",
    use_input_norm=True,
    range_norm=False,
    requires_grad=False,
    remove_pooling=False,
    pooling_stride=2,
):
    """Build a truncated, optionally-normalizing VGG feature extractor.

    The torchvision VGG is truncated right after the deepest layer named in
    `layer_name_list`, pooling layers are optionally removed or re-strided,
    and all parameters are frozen unless `requires_grad` is True.

    Args:
        layer_name_list (list[str]): names of layers whose activations
            forward() should return, e.g. ['relu1_1', 'relu2_1'].
        vgg_type (str): torchvision VGG variant; '_bn' suffix selects the
            batch-norm version. Default: 'vgg19'.
        use_input_norm (bool): normalize inputs with ImageNet mean/std
            (inputs assumed in [0, 1]). Default: True.
        range_norm (bool): map [-1, 1] inputs to [0, 1] in forward().
            Default: False.
        requires_grad (bool): optimize the VGG weights. Default: False.
        remove_pooling (bool): drop max-pooling layers. Default: False.
        pooling_stride (int): stride for the (kept) pooling layers. Default: 2.
    """
    super(VGGFeatureExtractor, self).__init__()

    self.layer_name_list = layer_name_list
    self.use_input_norm = use_input_norm
    self.range_norm = range_norm

    # layer-name table for the plain variant; '_bn' inserts bn entries
    self.names = NAMES[vgg_type.replace("_bn", "")]
    if "bn" in vgg_type:
        self.names = insert_bn(self.names)

    # only borrow layers that will be used to avoid unused params
    max_idx = 0
    for v in layer_name_list:
        idx = self.names.index(v)
        if idx > max_idx:
            max_idx = idx

    # prefer a local checkpoint; otherwise let torchvision download weights
    if os.path.exists(VGG_PRETRAIN_PATH):
        vgg_net = getattr(vgg, vgg_type)(pretrained=False)
        state_dict = torch.load(
            VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage
        )
        vgg_net.load_state_dict(state_dict)
    else:
        vgg_net = getattr(vgg, vgg_type)(pretrained=True)

    # truncate right after the deepest requested layer
    features = vgg_net.features[: max_idx + 1]

    modified_net = OrderedDict()
    # zip keeps names and modules aligned even when pool entries are skipped
    for k, v in zip(self.names, features):
        if "pool" in k:
            # if remove_pooling is true, pooling operation will be removed
            if remove_pooling:
                continue
            else:
                # in some cases, we may want to change the default stride
                modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
        else:
            modified_net[k] = v

    self.vgg_net = nn.Sequential(modified_net)

    if not requires_grad:
        self.vgg_net.eval()
        for param in self.parameters():
            param.requires_grad = False
    else:
        self.vgg_net.train()
        for param in self.parameters():
            param.requires_grad = True

    if self.use_input_norm:
        # the mean is for image with range [0, 1]
        self.register_buffer(
            "mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        )
        # the std is for image with range [0, 1]
        self.register_buffer(
            "std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        )
def forward(self, x):
    """Run the truncated VGG and collect the requested activations.

    Args:
        x (Tensor): input batch (n, c, h, w); in [-1, 1] if range_norm,
            otherwise in [0, 1] when use_input_norm is set.

    Returns:
        dict[str, Tensor]: cloned activations keyed by layer name, for every
        layer listed in `self.layer_name_list`.
    """
    if self.range_norm:
        x = (x + 1) / 2
    if self.use_input_norm:
        x = (x - self.mean) / self.std

    features = {}
    for name, layer in self.vgg_net._modules.items():
        x = layer(x)
        if name in self.layer_name_list:
            # clone so later layers cannot mutate the stored activation
            features[name] = x.clone()
    return features
def create_model(opt, **kwarg):
    """Instantiate the model class named by opt['model'] from MODEL_REGISTRY.

    Args:
        opt (dict): options; opt['model'] selects the registered model class,
            and the whole dict is passed to its constructor.
        **kwarg: extra keyword arguments forwarded to the model constructor.

    Returns:
        The constructed model instance.
    """
    model_name = opt["model"]
    model_cls = MODEL_REGISTRY.get(model_name)
    m = model_cls(opt, **kwarg)
    logger.info("Model [{:s}] is created.".format(m.__class__.__name__))
    return m
def setup_schedulers(self, scheduler_opt):
    """Create one LR scheduler per optimizer from `scheduler_opt`.

    Args:
        scheduler_opt (dict): scheduler config; its "type" key is popped and
            the remaining entries are passed to the scheduler constructor.
            Supported types: MultiStepLR, MultiStepRestartLR,
            CosineAnnealingRestartLR.

    Raises:
        NotImplementedError: for an unsupported scheduler type.
    """
    scheduler_type = scheduler_opt.pop("type")
    if scheduler_type in ["MultiStepLR", "MultiStepRestartLR"]:
        for name, optimizer in self.optimizers.items():
            self.schedulers[name] = MultiStepRestartLR(optimizer, **scheduler_opt)
    elif scheduler_type == "CosineAnnealingRestartLR":
        # BUGFIX: was `self.ptimizers`, which raised AttributeError whenever
        # a cosine-restart schedule was configured.
        for name, optimizer in self.optimizers.items():
            self.schedulers[name] = CosineAnnealingRestartLR(
                optimizer, **scheduler_opt
            )
    else:
        raise NotImplementedError(
            f"Scheduler {scheduler_type} is not implemented yet."
        )
optimizer, **scheduler_opt ) else: raise NotImplementedError( f"Scheduler {scheduler_type} is not implemented yet." ) def model_to_device(self, net): """Model to device. It also warps models with DistributedDataParallel or DataParallel. Args: net (nn.Module) """ net = net.to(self.device) if self.opt["dist"]: net = DistributedDataParallel(net, device_ids=[torch.cuda.current_device()]) else: net = DataParallel(net) return net def print_network(self, net): # Generator s, n = self.get_network_description(net) if isinstance(net, nn.DataParallel) or isinstance(net, DistributedDataParallel): net_struc_str = "{} - {}".format( net.__class__.__name__, net.module.__class__.__name__ ) else: net_struc_str = "{}".format(net.__class__.__name__) if self.rank <= 0: logger.info( "Network G structure: {}, with parameters: {:,d}".format( net_struc_str, n ) ) logger.info(s) def set_optimizer(self, names, operation): for name in names: getattr(self.optimizers[name], operation)() def set_requires_grad(self, names, requires_grad): for name in names: for v in self.networks[name].parameters(): v.requires_grad = requires_grad def set_network_state(self, names, state): for name in names: getattr(self.networks[name], state)() def clip_grad_norm(self, names, norm): for name in names: nn.utils.clip_grad_norm_(self.networks[name].parameters(), max_norm=norm) def _set_lr(self, lr_groups_l): """set learning rate for warmup, lr_groups_l: list for lr_groups. 
each for a optimizer""" for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): for param_group, lr in zip(optimizer.param_groups, lr_groups): param_group["lr"] = lr def _get_init_lr(self): # get the initial lr, which is set by the scheduler init_lr_groups_l = [] for optimizer in self.optimizers: init_lr_groups_l.append([v["initial_lr"] for v in optimizer.param_groups]) return init_lr_groups_l def update_learning_rate(self, cur_iter, warmup_iter=-1): for _, scheduler in self.schedulers.items(): scheduler.step() #### set up warm up learning rate if cur_iter < warmup_iter: # get initial lr for each group init_lr_g_l = self._get_init_lr() # modify warming-up learning rates warm_up_lr_l = [] for init_lr_g in init_lr_g_l: warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) # set learning rate self._set_lr(warm_up_lr_l) def get_current_learning_rate(self): # return self.schedulers[0].get_lr()[0] return list(self.optimizers.values())[0].param_groups[0]["lr"] def get_network_description(self, network): """Get the string and total parameters of the network""" if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module s = str(network) n = sum(map(lambda x: x.numel(), network.parameters())) return s, n def save_network(self, network, network_label, iter_label): save_filename = "{}_{}.pth".format(iter_label, network_label) save_path = os.path.join(self.opt["path"]["models"], save_filename) if isinstance(network, nn.DataParallel) or isinstance( network, DistributedDataParallel ): network = network.module state_dict = network.state_dict() for key, param in state_dict.items(): state_dict[key] = param.cpu() torch.save(state_dict, save_path) def save(self, iter_label): for name in self.optimizers.keys(): self.save_network(self.networks[name], name, iter_label) def load_network(self, network, load_path, strict=True): if load_path is not None: if isinstance(network, nn.DataParallel) or isinstance( 
network, DistributedDataParallel ): network = network.module load_net = torch.load(load_path) load_net_clean = OrderedDict() # remove unnecessary 'module.' for k, v in load_net.items(): if k.startswith("module."): load_net_clean[k[7:]] = v else: load_net_clean[k] = v network.load_state_dict(load_net_clean, strict=strict) def save_training_state(self, epoch, iter_step): """Saves training state during training, which will be used for resuming""" state = {"epoch": epoch, "iter": iter_step, "schedulers": {}, "optimizers": {}} for k, s in self.schedulers.items(): state["schedulers"][k] = s.state_dict() for k, o in self.optimizers.items(): state["optimizers"][k] = o.state_dict() save_filename = "{}.state".format(iter_step) save_path = os.path.join(self.opt["path"]["training_state"], save_filename) torch.save(state, save_path) def resume_training(self, resume_state): """Resume the optimizers and schedulers for training""" resume_optimizers = resume_state["optimizers"] resume_schedulers = resume_state["schedulers"] assert len(resume_optimizers) == len( self.optimizers ), "Wrong lengths of optimizers" assert len(resume_schedulers) == len( self.schedulers ), "Wrong lengths of schedulers" for name, o in resume_optimizers.items(): self.optimizers[name].load_state_dict(o) for name, s in resume_schedulers.items(): self.schedulers[name].load_state_dict(s) def reduce_loss_dict(self, loss_dict): """reduce loss dict. In distributed training, it averages the losses among different GPUs . Args: loss_dict (OrderedDict): Loss dict. 
""" with torch.no_grad(): if self.opt["dist"]: keys = [] losses = [] for name, value in loss_dict.items(): keys.append(name) losses.append(value) losses = torch.stack(losses, 0) torch.distributed.reduce(losses, dst=0) if self.rank == 0: losses /= self.world_size loss_dict = {key: loss for key, loss in zip(keys, losses)} log_dict = OrderedDict() for name, value in loss_dict.items(): log_dict[name] = value.mean().item() return log_dict def get_current_log(self): return self.log_dict ================================================ FILE: codes/config/RealESRGAN/models/sr_model.py ================================================ import logging from collections import OrderedDict import torch import torch.nn as nn from utils.registry import MODEL_REGISTRY from .base_model import BaseModel logger = logging.getLogger("base") @MODEL_REGISTRY.register() class SRModel(BaseModel): def __init__(self, opt): super().__init__(opt) self.data_names = ["lr", "hr"] self.network_names = ["netSR"] self.networks = {} self.loss_names = ["sr_adv", "sr_pix", "sr_percep"] self.loss_weights = {} self.losses = {}import logging from collections import OrderedDict import torch import torch.nn as nn from utils.registry import MODEL_REGISTRY from .base_model import BaseModel logger = logging.getLogger("base") @MODEL_REGISTRY.register() class SRModel(BaseModel): def __init__(self, opt): super().__init__(opt) self.data_names = ["lr", "hr"] self.network_names = ["netSR"] self.networks = {} self.loss_names = ["sr_adv", "sr_pix", "sr_percep"] self.loss_weights = {} self.losses = {} self.optimizers = {} # define networks and load pretrained models nets_opt = opt["networks"] defined_network_names = list(nets_opt.keys()) assert set(defined_network_names).issubset(set(self.network_names)) for name in defined_network_names: setattr(self, name, self.build_network(nets_opt[name])) self.networks[name] = getattr(self, name) if self.is_train: # setup loss, optimizers, schedulers self.setup_train(opt["train"]) 
def feed_data(self, data): self.lr = data["src"].to(self.device) self.hr = data["tgt"].to(self.device) def forward(self): self.sr = self.netSR(self.lr) def optimize_parameters(self, step): self.forward() loss_dict = OrderedDict() l_sr = 0 sr_pix = self.losses["sr_pix"](self.hr, self.sr) loss_dict["sr_pix"] = sr_pix l_sr += self.loss_weights["sr_pix"] * sr_pix if self.losses.get("sr_adv"): self.set_requires_grad(["netD"], False) sr_adv_g = self.calculate_rgan_loss_G( self.netD, self.losses["sr_adv"], self.hr, self.sr ) loss_dict["sr_adv_g"] = sr_adv_g l_sr += self.loss_weights["sr_adv"] * sr_adv_g if self.losses.get("sr_percep"): sr_percep, sr_style = self.losses["sr_percep"](self.hr, self.sr) loss_dict["sr_percep"] = sr_percep if sr_style is not None: loss_dict["sr_style"] = sr_style l_sr += self.loss_weights["sr_percep"] * sr_style l_sr += self.loss_weights["sr_percep"] * sr_percep self.set_optimizer(names=["netSR"], operation="zero_grad") l_sr.backward() self.set_optimizer(names=["netSR"], operation="step") if self.losses.get("sr_adv"): self.set_requires_grad(["netD"], True) sr_adv_d = self.calculate_rgan_loss_D( self.netD, self.losses["sr_adv"], self.hr, self.sr ) loss_dict["sr_adv_d"] = sr_adv_d self.optimizers["netD"].zero_grad() sr_adv_d.backward() self.optimizers["netD"].step() self.log_dict = self.reduce_loss_dict(loss_dict) def calculate_rgan_loss_D(self, netD, criterion, real, fake): d_pred_fake = netD(fake.detach()) d_pred_real = netD(real) loss_real = criterion( d_pred_real - d_pred_fake.detach().mean(), True, is_disc=False ) loss_fake = criterion( d_pred_fake - d_pred_real.detach().mean(), False, is_disc=False ) loss = (loss_real + loss_fake) / 2 return loss def calculate_rgan_loss_G(self, netD, criterion, real, fake): d_pred_fake = netD(fake) d_pred_real = netD(real).detach() loss_real = criterion(d_pred_real - d_pred_fake.mean(), False, is_disc=False) loss_fake = criterion(d_pred_fake - d_pred_real.mean(), True, is_disc=False) loss = (loss_real + 
loss_fake) / 2 return loss def test(self, data, crop_size=None): self.real_lr = data["src"].to(self.device) self.netSR.eval() with torch.no_grad(): if crop_size is None: self.fake_real_hr = self.netSR(self.real_lr) else: self.fake_real_hr = self.crop_test(self.real_lr, crop_size) self.netSR.train() def crop_test(self, lr, crop_size): b, c, h, w = lr.shape scale = self.opt["scale"] h_start = list(range(0, h-crop_size, crop_size)) w_start = list(range(0, w-crop_size, crop_size)) sr1 = torch.zeros(b, c, int(h*scale), int(w* scale), device=self.device) - 1 for hs in h_start: for ws in w_start: lr_patch = lr[:, :, hs: hs+crop_size, ws: ws+crop_size] sr_patch = self.netSR(lr_patch) sr1[:, :, int(hs*scale):int((hs+crop_size)*scale), int(ws*scale):int((ws+crop_size)*scale) ] = sr_patch h_end = list(range(h, crop_size, -crop_size)) w_end = list(range(w, crop_size, -crop_size)) sr2 = torch.zeros(b, c, int(h*scale), int(w* scale), device=self.device) - 1 for hd in h_end: for wd in w_end: lr_patch = lr[:, :, hd-crop_size:hd, wd-crop_size:wd] sr_patch = self.netSR(lr_patch) sr2[:, :, int((hd-crop_size)*scale):int(hd*scale), int((wd-crop_size)*scale):int(wd*scale) ] = sr_patch mask1 = ( (sr1 == -1).float() * 0 + (sr2 == -1).float() * 1 + ((sr1 > 0) * (sr2 > 0)).float() * 0.5 ) mask2 = ( (sr1 == -1).float() * 1 + (sr2 == -1).float() * 0 + ((sr1 > 0) * (sr2 > 0)).float() * 0.5 ) sr = mask1 * sr1 + mask2 * sr2 return sr def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict["lr"] = self.real_lr.detach()[0].float().cpu() out_dict["sr"] = self.fake_real_hr.detach()[0].float().cpu() return out_dict self.optimizers = {} # define networks and load pretrained models nets_opt = opt["networks"] defined_network_names = list(nets_opt.keys()) assert set(defined_network_names).issubset(set(self.network_names)) for name in defined_network_names: setattr(self, name, self.build_network(nets_opt[name])) self.networks[name] = getattr(self, name) if self.is_train: train_opt = 
opt["train"] # define losses loss_opt = train_opt["losses"] defined_loss_names = list(loss_opt.keys()) assert set(defined_loss_names).issubset(set(self.loss_names)) for name in defined_loss_names: loss_conf = loss_opt.get(name) if loss_conf["weight"] > 0: self.loss_weights[name] = loss_conf.pop("weight") self.losses[name] = self.build_loss(loss_conf) # build optmizers optimizer_opt = train_opt["optimizers"] defined_optimizer_names = list(optimizer_opt.keys()) assert set(defined_optimizer_names).issubset(self.networks.keys()) for name in defined_optimizer_names: optim_config = optimizer_opt[name] self.optimizers[name] = self.build_optimizer( getattr(self, name), optim_config ) # set schedulers scheduler_opt = train_opt["scheduler"] self.setup_schedulers(scheduler_opt) # set to training state self.set_network_state(self.networks.keys(), "train") def feed_data(self, data): self.lr = data["src"].to(self.device) self.hr = data["tgt"].to(self.device) def forward(self): self.sr = self.netSR(self.lr) def optimize_parameters(self, step): self.forward() loss_dict = OrderedDict() l_sr = 0 sr_pix = self.losses["sr_pix"](self.hr, self.sr) loss_dict["sr_pix"] = sr_pix l_sr += self.loss_weights["sr_pix"] * sr_pix if self.losses.get("sr_adv"): self.set_requires_grad(["netD"], False) sr_adv_g = self.calculate_rgan_loss_G( self.netD, self.losses["sr_adv"], self.hr, self.sr ) loss_dict["sr_adv_g"] = sr_adv_g l_sr += self.loss_weights["sr_adv"] * sr_adv_g if self.losses.get("sr_percep"): sr_percep, sr_style = self.losses["sr_percep"](self.hr, self.sr) loss_dict["sr_percep"] = sr_percep if sr_style is not None: loss_dict["sr_style"] = sr_style l_sr += self.loss_weights["sr_percep"] * sr_style l_sr += self.loss_weights["sr_percep"] * sr_percep self.set_optimizer(names=["netSR"], operation="zero_grad") l_sr.backward() self.set_optimizer(names=["netSR"], operation="step") if self.losses.get("sr_adv"): self.set_requires_grad(["netD"], True) sr_adv_d = self.calculate_rgan_loss_D( self.netD, 
self.losses["sr_adv"], self.hr, self.sr ) loss_dict["sr_adv_d"] = sr_adv_d self.optimizers["netD"].zero_grad() sr_adv_d.backward() self.optimizers["netD"].step() self.log_dict = self.reduce_loss_dict(loss_dict) def calculate_rgan_loss_D(self, netD, criterion, real, fake): d_pred_fake = netD(fake.detach()) d_pred_real = netD(real) loss_real = criterion( d_pred_real - d_pred_fake.detach().mean(), True, is_disc=False ) loss_fake = criterion( d_pred_fake - d_pred_real.detach().mean(), False, is_disc=False ) loss = (loss_real + loss_fake) / 2 return loss def calculate_rgan_loss_G(self, netD, criterion, real, fake): d_pred_fake = netD(fake) d_pred_real = netD(real).detach() loss_real = criterion(d_pred_real - d_pred_fake.mean(), False, is_disc=False) loss_fake = criterion(d_pred_fake - d_pred_real.mean(), True, is_disc=False) loss = (loss_real + loss_fake) / 2 return loss def test(self, data, crop_size=None): self.real_lr = data["src"].to(self.device) self.netSR.eval() with torch.no_grad(): if crop_size is None: self.fake_real_hr = self.netSR(self.real_lr) else: self.fake_real_hr = self.crop_test(self.real_lr, crop_size) self.netSR.train() def crop_test(self, lr, crop_size): b, c, h, w = lr.shape scale = self.opt["scale"] h_start = list(range(0, h-crop_size, crop_size)) w_start = list(range(0, w-crop_size, crop_size)) sr1 = torch.zeros(b, c, int(h*scale), int(w* scale), device=self.device) - 1 for hs in h_start: for ws in w_start: lr_patch = lr[:, :, hs: hs+crop_size, ws: ws+crop_size] sr_patch = self.netSR(lr_patch) sr1[:, :, int(hs*scale):int((hs+crop_size)*scale), int(ws*scale):int((ws+crop_size)*scale) ] = sr_patch h_end = list(range(h, crop_size, -crop_size)) w_end = list(range(w, crop_size, -crop_size)) sr2 = torch.zeros(b, c, int(h*scale), int(w* scale), device=self.device) - 1 for hd in h_end: for wd in w_end: lr_patch = lr[:, :, hd-crop_size:hd, wd-crop_size:wd] sr_patch = self.netSR(lr_patch) sr2[:, :, int((hd-crop_size)*scale):int(hd*scale), 
int((wd-crop_size)*scale):int(wd*scale) ] = sr_patch mask1 = ( (sr1 == -1).float() * 0 + (sr2 == -1).float() * 1 + ((sr1 > 0) * (sr2 > 0)).float() * 0.5 ) mask2 = ( (sr1 == -1).float() * 1 + (sr2 == -1).float() * 0 + ((sr1 > 0) * (sr2 > 0)).float() * 0.5 ) sr = mask1 * sr1 + mask2 * sr2 return sr def get_current_visuals(self, need_GT=True): out_dict = OrderedDict() out_dict["lr"] = self.real_lr.detach()[0].float().cpu() out_dict["sr"] = self.fake_real_hr.detach()[0].float().cpu() return out_dict ================================================ FILE: codes/config/RealESRGAN/options/test/2017Track2_2020Track1.yml ================================================ #### general settings name: 2017Track2_2020Track1 use_tb_logger: false model: SRModel scale: 4 gpu_ids: [6] metrics: [psnr, ssim, lpips, niqe, piqe, brisque] datasets: test1: name: 2017Track2 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2017/valid_LR/x4.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb test2: name: 2020Track1 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2020/track1/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netSR: which_network: RRDBNet setting: in_nc: 3 out_nc: 3 nf: 64 nb: 23 gc: 32 upscale: 4 pretrain: path: ../../../checkpoints/RealESRGAN/RealESRGANx4.pth strict_load: true ================================================ FILE: codes/config/RealESRGAN/options/test/2018Track2_2018Track4.yml ================================================ #### general settings name: 2018Track2_2018Track4 use_tb_logger: false model: SRModel scale: 4 gpu_ids: [6] metrics: [best_psnr, best_ssim, best_lpips, niqe, piqe, brisque] datasets: test1: name: 2018Track2 mode: PairedDataset data_type: lmdb dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track2/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb test2: name: 2018Track4 mode: PairedDataset data_type: lmdb 
dataroot_src: /home/lzx/SRDatasets/NTIRE2018/track4/valid.lmdb dataroot_tgt: /home/lzx/SRDatasets/DIV2K_valid/HR/x4.lmdb #### network structures networks: netSR: which_network: RRDBNet setting: in_nc: 3 out_nc: 3 nf: 64 nb: 23 gc: 32 upscale: 4 pretrain: path: ../../../checkpoints/RealESRGAN/RealESRGANx4.pth strict_load: true ================================================ FILE: codes/config/RealESRGAN/options/test/2020Track2.yml ================================================ #### general settings name: 2020Track2 use_tb_logger: false model: SRModel scale: 4 gpu_ids: [5] metrics: [niqe, piqe, brisque] datasets: test1: name: 2020Track2 mode: SingleDataset data_type: lmdb dataroot: /home/lzx/SRDatasets/NTIRE2020/track2/test.lmdb #### network structures networks: netSR: which_network: RRDBNet setting: in_nc: 3 out_nc: 3 nf: 64 nb: 23 gc: 32 upscale: 4 pretrain: path: ../../../checkpoints/RealESRGAN/RealESRGANx4.pth strict_load: true ================================================ FILE: codes/config/RealESRGAN/test.py ================================================ import argparse import logging import os.path import sys import time from collections import OrderedDict, defaultdict import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from metrics import IQA from models import create_model from utils import bgr2ycbcr, imresize def parse_args(): parser = argparse.ArgumentParser(description="Train keypoints network") # general parser.add_argument( "--opt", help="experiment configure file name", required=True, type=str ) parser.add_argument( "--root_path", help="experiment configure file name", default="../../../", type=str, ) # distributed training parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str) parser.add_argument( "--world-size", default=1, type=int, 
help="number of nodes for distributed training", ) parser.add_argument( "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training", ) parser.add_argument( "--rank", default=0, type=int, help="node rank for distributed training" ) args = parser.parse_args() return args def main(): args = parse_args() opt = option.parse(args.opt, args.root_path, is_train=False) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) if args.dist_url == "env://" and args.world_size == -1: args.world_size = int(os.environ["WORLD_SIZE"]) ngpus_per_node = torch.cuda.device_count() args.world_size = ngpus_per_node * args.world_size opt["dist"] = args.world_size > 1 util.mkdirs( (path for key, path in opt["path"].items() if not key == "experiments_root") ) os.system("rm ./result") os.symlink(os.path.join(opt["path"]["results_root"], ".."), "./result") if opt["dist"]: mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args)) else: main_worker(0, 1, opt, args) def main_worker(gpu, ngpus_per_node, opt, args): if opt["dist"]: if args.dist_url == "env://" and args.rank == -1: rank = int(os.environ["RANK"]) rank = args.rank * ngpus_per_node + gpu print( f"Init process group: dist_url: {args.dist_url}, world_size: {args.world_size}, rank: {rank}" ) dist.init_process_group( backend="nccl", init_method=args.dist_url, world_size=args.world_size, rank=rank, ) torch.cuda.set_device(gpu) else: rank = 0 torch.backends.cudnn.benchmark = True util.setup_logger( "base", opt["path"]["log"], "test_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO, screen=True, tofile=True, ) measure = IQA(metrics=opt["metrics"], cuda=True) logger = logging.getLogger("base") logger.info(option.dict2str(opt)) # Create test dataset and dataloader test_datasets = [] test_loaders = [] for phase, dataset_opt in sorted(opt["datasets"].items()): test_set = create_dataset(dataset_opt) test_loader = 
create_dataloader(test_set, dataset_opt, opt["dist"]) if rank == 0: logger.info( "Number of test images in [{:s}]: {:d}".format( dataset_opt["name"], len(test_set) ) ) test_datasets.append(test_set) test_loaders.append(test_loader) # load pretrained model by default model = create_model(opt) for test_dataset, test_loader in zip(test_datasets, test_loaders): test_set_name = test_dataset.opt["name"] dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name) if rank == 0: logger.info("\nTesting [{:s}]...".format(test_set_name)) util.mkdir(dataset_dir) validate( model, test_dataset, test_loader, opt, measure, dataset_dir, test_set_name, logger, ) def validate( model, dataset, dist_loader, opt, measure, dataset_dir, test_set_name, logger ): test_results = {} test_results_y = {} for metric in opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() test_results_y[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 indices = list(range(rank, len(dataset), world_size)) for ( idx, test_data, ) in enumerate(dist_loader): idx = indices[idx] img_path = test_data["src_path"][0] img_name = img_path.split("/")[-1].split(".")[0] model.test(test_data) visuals = model.get_current_visuals() sr_img = util.tensor2img(visuals["sr"]) # uint8 suffix = opt["suffix"] if suffix: save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png") else: save_img_path = os.path.join(dataset_dir, img_name + ".png") util.save_img(sr_img, save_img_path) message = "img:{:15s}; ".format(img_name) crop_border = opt["crop_border"] if opt["crop_border"] else opt["scale"] if crop_border == 0: cropped_sr_img = sr_img else: cropped_sr_img = sr_img[ crop_border:-crop_border, crop_border:-crop_border, : ] if "tgt" in test_data.keys(): gt_img = util.tensor2img(test_data["tgt"][0].double().cpu()) if crop_border == 0: cropped_gt_img = gt_img else: cropped_gt_img = gt_img[ 
crop_border:-crop_border, crop_border:-crop_border, : ] else: gt_img = None cropped_gt_img = None message += "Scores - " scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v message += "{}: {:.6f}; ".format(k, v) if sr_img.shape[2] == 3: # RGB image sr_img_y = bgr2ycbcr(sr_img, only_y=True) if crop_border == 0: cropped_sr_img_y = sr_img_y * 255 else: cropped_sr_img_y = ( sr_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255 ) if gt_img is not None: gt_img_y = bgr2ycbcr(gt_img, only_y=True) if crop_border == 0: cropped_gt_img_y = gt_img_y * 255 else: cropped_gt_img_y = ( gt_img_y[crop_border:-crop_border, crop_border:-crop_border] * 255 ) else: gt_img_y = None cropped_gt_img_y = None message += "Y Scores - " scores = measure( res=cropped_sr_img_y, ref=cropped_gt_img_y, metrics=opt["metrics"] ) for k, v in scores.items(): test_results_y[k][idx] = v message += "{}: {:.6f}; ".format(k, v) logger.info(message) if opt["dist"]: for k, v in test_results.items(): dist.reduce(v, dst=0) dist.barrier() for k, v in test_results_y.items(): dist.reduce(v, dst=0) dist.barrier() # log avg_results = {} message = "Average Results for {}\n".format(test_set_name) if rank == 0: for k, v in test_results.items(): avg_results[k] = sum(v) / len(v) message += "{}: {:.6f}; ".format(k, avg_results[k]) logger.info(message) avg_results_y = {} message = "Average Results on Y channel for {}\n".format(test_set_name) if rank == 0: for k, v in test_results_y.items(): avg_results[k] = sum(v) / len(v) message += "{}: {:.6f}; ".format(k, avg_results[k]) logger.info(message) if __name__ == "__main__": main() ================================================ FILE: codes/config/RealESRGAN/train.py ================================================ import argparse import logging import math import os import random import sys import time from collections import defaultdict import numpy as np import torch import 
torch.distributed as dist import torch.multiprocessing as mp from tensorboardX import SummaryWriter from tqdm import tqdm sys.path.append("../../") import utils as util import utils.option as option from data import create_dataloader, create_dataset from metrics import IQA from models import create_model def parse_args(): parser = argparse.ArgumentParser(description="Train keypoints network") # general parser.add_argument( "--opt", help="experiment configure file name", required=True, type=str ) parser.add_argument( "--root_path", help="experiment configure file name", default="../../../", type=str, ) # distributed training parser.add_argument("--gpu", help="gpu id for multiprocessing training", type=str) parser.add_argument( "--world-size", default=1, type=int, help="number of nodes for distributed training", ) parser.add_argument( "--dist-url", default="tcp://127.0.0.1:23456", type=str, help="url used to set up distributed training", ) parser.add_argument( "--rank", default=0, type=int, help="node rank for distributed training" ) args = parser.parse_args() return args def setup_dataloaer(opt, logger): if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: rank = 0 world_size = 1 for phase, dataset_opt in opt["datasets"].items(): if phase == "train": train_set = create_dataset(dataset_opt) train_loader = create_dataloader(train_set, dataset_opt, opt["dist"]) total_iters = opt["train"]["niter"] total_epochs = total_iters // (len(train_loader) - 1) + 1 if rank == 0: logger.info( "Number of train images: {:,d}, iters: {:,d}".format( len(train_set), len(train_loader) ) ) logger.info( "Total epochs needed: {:d} for iters {:,d}".format( total_epochs, opt["train"]["niter"] ) ) elif phase == "val": val_set = create_dataset(dataset_opt) val_loader = create_dataloader(val_set, dataset_opt, opt["dist"]) if rank == 0: logger.info( "Number of val images in [{:s}]: {:d}".format( dataset_opt["name"], len(val_set) ) ) else: raise 
NotImplementedError("Phase [{:s}] is not recognized.".format(phase)) assert train_loader is not None assert val_loader is not None return train_set, train_loader, val_set, val_loader, total_iters, total_epochs def main(): args = parse_args() opt = option.parse(args.opt, args.root_path, is_train=True) # convert to NoneDict, which returns None for missing keys opt = option.dict_to_nonedict(opt) if args.dist_url == "env://" and args.world_size == -1: args.world_size = int(os.environ["WORLD_SIZE"]) ngpus_per_node = torch.cuda.device_count() args.world_size = ngpus_per_node * args.world_size opt["dist"] = args.world_size > 1 if opt["train"].get("resume_state", None) is None: util.mkdir_and_rename( opt["path"]["experiments_root"] ) # rename experiment folder if exists util.mkdirs( (path for key, path in opt["path"].items() if not key == "experiments_root") ) os.system("rm ./log") os.symlink(os.path.join(opt["path"]["experiments_root"], ".."), "./log") if opt["dist"]: mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt, args)) else: main_worker(0, 1, opt, args) def main_worker(gpu, ngpus_per_node, opt, args): if opt["dist"]: if args.dist_url == "env://" and args.rank == -1: rank = int(os.environ["RANK"]) rank = args.rank * ngpus_per_node + gpu print( f"Init process group: dist_url: \ {args.dist_url}, world_size: {args.world_size}, rank: {rank}" ) dist.init_process_group( backend="nccl", init_method=args.dist_url, world_size=args.world_size, rank=rank, ) torch.cuda.set_device(gpu) else: rank = 0 seed = opt["train"]["manual_seed"] if seed is None: util.set_random_seed(rank) torch.backends.cudnn.benchmark = True # torch.backends.cudnn.deterministic = True # setup tensorboard and val logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger = SummaryWriter(log_dir="log/{}/tb_logger/".format(opt["name"])) util.setup_logger( "val", opt["path"]["log"], "val_" + opt["name"], level=logging.INFO, screen=True, tofile=True, ) measure = 
IQA(metrics=opt["metrics"], cuda=True) # config loggers. Before it, the log will not work util.setup_logger( "base", opt["path"]["log"], "train_" + opt["name"] + "_rank{}".format(rank), level=logging.INFO if rank == 0 else logging.ERROR, screen=True, tofile=True, ) logger = logging.getLogger("base") if rank == 0: logger.info(option.dict2str(opt)) # create dataset ( train_set, train_loader, val_set, val_loader, total_iters, total_epochs, ) = setup_dataloaer(opt, logger) # create model model = create_model(opt) # loading resume state if exists if opt["train"].get("resume_state", None): # distributed resuming: all load into default GPU device_id = gpu resume_state = torch.load( opt["train"]["resume_state"], map_location=lambda storage, loc: storage.cuda(device_id), ) logger.info( "Resuming training from epoch: {}, iter: {}.".format( resume_state["epoch"], resume_state["iter"] ) ) start_epoch = resume_state["epoch"] current_step = resume_state["iter"] model.resume_training(resume_state) # handle optimizers and schedulers else: current_step = 0 start_epoch = 0 logger.info( "Start training from epoch: {:d}, iter: {:d}".format(start_epoch, current_step) ) data_time, iter_time = time.time(), time.time() avg_data_time = avg_iter_time = 0 count = 0 for epoch in range(start_epoch, total_epochs + 1): for _, train_data in enumerate(train_loader): current_step += 1 count += 1 if current_step > total_iters: break data_time = time.time() - data_time avg_data_time = (avg_data_time * (count - 1) + data_time) / count model.feed_data(train_data) model.optimize_parameters(current_step) model.update_learning_rate( current_step, warmup_iter=opt["train"]["warmup_iter"] ) iter_time = time.time() - iter_time avg_iter_time = (avg_iter_time * (count - 1) + iter_time) / count # log if current_step % opt["logger"]["print_freq"] == 0: logs = model.get_current_log() message = ( f" " ) message += f'[time (data): {avg_iter_time:.3f} ({avg_data_time:.3f})] ' for k, v in logs.items(): message += 
"{:s}: {:.4e}; ".format(k, v) # tensorboard logger if opt["use_tb_logger"] and "debug" not in opt["name"]: if rank == 0: tb_logger.add_scalar(k, v, current_step) logger.info(message) # validation if current_step % opt["train"]["val_freq"] == 0: avg_results = validate( model, val_set, val_loader, opt, measure, epoch, current_step ) # tensorboard logger if rank == 0: if opt["use_tb_logger"] and "debug" not in opt["name"]: for k, v in avg_results.items(): tb_logger.add_scalar(k, v, current_step) # save models and training states if current_step % opt["logger"]["save_checkpoint_freq"] == 0: if rank == 0: logger.info("Saving models and training states.") model.save(current_step) model.save_training_state(epoch, current_step) data_time = time.time() iter_time = time.time() if rank == 0: logger.info("Saving the final model.") model.save("latest") logger.info("End of training.") if opt["use_tb_logger"] and "debug" not in opt["name"]: tb_logger.close() def validate(model, dataset, dist_loader, opt, measure, epoch, current_step): test_results = {} for metric in opt["metrics"]: test_results[metric] = torch.zeros((len(dataset))).cuda() if opt["dist"]: rank = dist.get_rank() world_size = dist.get_world_size() else: world_size = 1 rank = 0 if rank == 0: pbar = tqdm(total=len(dataset), leave=False, dynamic_ncols=True) indices = list(range(rank, len(dataset), world_size)) for ( idx, val_data, ) in enumerate(dist_loader): idx = indices[idx] LR_img = val_data["src"] lr_img = util.tensor2img(LR_img) # save LR image for reference model.test(val_data) visuals = model.get_current_visuals() # Save images for reference img_name = val_data["src_path"][0].split("/")[-1].split(".")[0] img_dir = os.path.join(opt["path"]["val_images"], img_name) util.mkdir(img_dir) save_lr_path = os.path.join(img_dir, "{:s}_LR.png".format(img_name)) util.save_img(lr_img, save_lr_path) sr_img = util.tensor2img(visuals["sr"]) # uint8 save_img_path = os.path.join( img_dir, "{:s}_{:d}.png".format(img_name, 
current_step) ) util.save_img(sr_img, save_img_path) if "fake_lr" in visuals.keys(): fake_lr_img = util.tensor2img(visuals["fake_lr"]) save_img_path = os.path.join( img_dir, f"fake_lr_{current_step:d}.png" ) util.save_img(fake_lr_img, save_img_path) # calculate scores crop_size = opt["scale"] cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] if "tgt" in val_data.keys(): gt_img = util.tensor2img(val_data["tgt"]) cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] else: cropped_gt_img = gt_img = None scores = measure(res=cropped_sr_img, ref=cropped_gt_img, metrics=opt["metrics"]) for k, v in scores.items(): test_results[k][idx] = v if rank == 0: for _ in range(world_size): pbar.update(1) if rank == 0: pbar.close() # log avg_results = {} message = " thres_sz: h_space = np.append(h_space, h - crop_sz) w_space = np.arange(0, w - crop_sz + 1, step) if w - (w_space[-1] + crop_sz) > thres_sz: w_space = np.append(w_space, w - crop_sz) index = 0 for x in h_space: for y in w_space: index += 1 if n_channels == 2: crop_img = img[x : x + crop_sz, y : y + crop_sz] else: crop_img = img[x : x + crop_sz, y : y + crop_sz, :] crop_img = np.ascontiguousarray(crop_img) # var = np.var(crop_img / 255) # if var > 0.008: # print(img_name, index_str, var) cv2.imwrite( os.path.join( save_folder, img_name.replace(".png", "_s{:03d}.png".format(index)) ), crop_img, [cv2.IMWRITE_PNG_COMPRESSION, compression_level], ) return "Processing {:s} ...".format(img_name) if __name__ == "__main__": main() ================================================ FILE: codes/scripts/generate_mod_LR_bic.m ================================================ function generate_mod_LR_bic() %% matlab code to genetate mod images, bicubic-downsampled LR, bicubic_upsampled images. 
%% set parameters % comment the unnecessary line input_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub'; % save_mod_folder = ''; save_LR_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub_bicLRx4'; % save_bic_folder = ''; up_scale = 4; mod_scale = 4; if exist('save_mod_folder', 'var') if exist(save_mod_folder, 'dir') disp(['It will cover ', save_mod_folder]); else mkdir(save_mod_folder); end end if exist('save_LR_folder', 'var') if exist(save_LR_folder, 'dir') disp(['It will cover ', save_LR_folder]); else mkdir(save_LR_folder); end end if exist('save_bic_folder', 'var') if exist(save_bic_folder, 'dir') disp(['It will cover ', save_bic_folder]); else mkdir(save_bic_folder); end end idx = 0; filepaths = dir(fullfile(input_folder,'*.*')); for i = 1 : length(filepaths) [paths,imname,ext] = fileparts(filepaths(i).name); if isempty(imname) disp('Ignore . folder.'); elseif strcmp(imname, '.') disp('Ignore .. folder.'); else idx = idx + 1; str_rlt = sprintf('%d\t%s.\n', idx, imname); fprintf(str_rlt); % read image img = imread(fullfile(input_folder, [imname, ext])); img = im2double(img); % modcrop img = modcrop(img, mod_scale); if exist('save_mod_folder', 'var') imwrite(img, fullfile(save_mod_folder, [imname, '.png'])); end % LR im_LR = imresize(img, 1/up_scale, 'bicubic'); if exist('save_LR_folder', 'var') imwrite(im_LR, fullfile(save_LR_folder, [imname, '_bicLRx4.png'])); end % Bicubic if exist('save_bic_folder', 'var') im_B = imresize(im_LR, up_scale, 'bicubic'); imwrite(im_B, fullfile(save_bic_folder, [imname, '_bicx4.png'])); end end end end %% modcrop function img = modcrop(img, modulo) if size(img,3) == 1 sz = size(img); sz = sz - mod(sz, modulo); img = img(1:sz(1), 1:sz(2)); else tmpsz = size(img); sz = tmpsz(1:2); sz = sz - mod(sz, modulo); img = img(1:sz(1), 1:sz(2),:); end end ================================================ FILE: codes/scripts/generate_mod_LR_bic.py ================================================ import os 
import sys import cv2 import numpy as np try: sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from utils import imresize except ImportError: pass def generate_mod_LR_bic(): # set parameters up_scale = 2 mod_scale = 2 # set data dir sourcedir = "/mnt/hdd/lzx/SRDatasets/NTIRE2018/DIV2K_valid_HR/" savedir = "/mnt/hdd/lzx/SRDatasets/DIV2K_valid/" saveHRpath = os.path.join(savedir, "HR", "x" + str(mod_scale)) saveLRpath = os.path.join(savedir, "BicLR", "x" + str(up_scale)) # saveBicpath = os.path.join(savedir, "Bic", "x" + str(up_scale)) if not os.path.isdir(sourcedir): print("Error: No source data found") exit(0) if not os.path.isdir(savedir): os.mkdir(savedir) if not os.path.isdir(os.path.join(savedir, "HR")): os.mkdir(os.path.join(savedir, "HR")) if not os.path.isdir(os.path.join(savedir, "BicLR")): os.mkdir(os.path.join(savedir, "BicLR")) # if not os.path.isdir(os.path.join(savedir, "Bic")): # os.mkdir(os.path.join(savedir, "Bic")) if not os.path.isdir(saveHRpath): os.mkdir(saveHRpath) else: print("It will cover " + str(saveHRpath)) if not os.path.isdir(saveLRpath): os.mkdir(saveLRpath) else: print("It will cover " + str(saveLRpath)) # if not os.path.isdir(saveBicpath): # os.mkdir(saveBicpath) # else: # print("It will cover " + str(saveBicpath)) filepaths = [f for f in os.listdir(sourcedir) if f.endswith(".png")] num_files = len(filepaths) # prepare data with augementation for i in range(num_files): filename = filepaths[i] print("No.{} -- Processing {}".format(i, filename)) # read image image = cv2.imread(os.path.join(sourcedir, filename)) width = int(np.floor(image.shape[1] / mod_scale)) height = int(np.floor(image.shape[0] / mod_scale)) # modcrop if len(image.shape) == 3: image_HR = image[0 : mod_scale * height, 0 : mod_scale * width, :] else: image_HR = image[0 : mod_scale * height, 0 : mod_scale * width] # LR image_LR = imresize(image_HR, 1 / up_scale, True) # bic # image_Bic = imresize(image_LR, up_scale, True) 
cv2.imwrite(os.path.join(saveHRpath, filename), image_HR) cv2.imwrite(os.path.join(saveLRpath, filename), image_LR) # cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic) if __name__ == "__main__": generate_mod_LR_bic() ================================================ FILE: codes/scripts/generate_mod_blur_LR_bic.py ================================================ import os import sys import cv2 import numpy as np import torch try: sys.path.append("..") from utils import imresize import utils as util except ImportError: pass def generate_mod_LR_bic(): # set parameters up_scale = 4 mod_scale = 4 # set data dir sourcedir = "/data/Set5/source/" savedir = "/data/Set5/" # load PCA matrix of enough kernel print("load PCA matrix") pca_matrix = torch.load( "../../pca_matrix.pth", map_location=lambda storage, loc: storage ) print("PCA matrix shape: {}".format(pca_matrix.shape)) degradation_setting = { "random_kernel": False, "code_length": 10, "ksize": 21, "pca_matrix": pca_matrix, "scale": up_scale, "cuda": True, "rate_iso": 1.0, } # set random seed util.set_random_seed(0) saveHRpath = os.path.join(savedir, "HR", "x" + str(mod_scale)) saveLRpath = os.path.join(savedir, "LR", "x" + str(up_scale)) saveBicpath = os.path.join(savedir, "Bic", "x" + str(up_scale)) saveLRblurpath = os.path.join(savedir, "LRblur", "x" + str(up_scale)) if not os.path.isdir(sourcedir): print("Error: No source data found") exit(0) if not os.path.isdir(savedir): os.mkdir(savedir) if not os.path.isdir(os.path.join(savedir, "HR")): os.mkdir(os.path.join(savedir, "HR")) if not os.path.isdir(os.path.join(savedir, "LR")): os.mkdir(os.path.join(savedir, "LR")) if not os.path.isdir(os.path.join(savedir, "Bic")): os.mkdir(os.path.join(savedir, "Bic")) if not os.path.isdir(os.path.join(savedir, "LRblur")): os.mkdir(os.path.join(savedir, "LRblur")) if not os.path.isdir(saveHRpath): os.mkdir(saveHRpath) else: print("It will cover " + str(saveHRpath)) if not os.path.isdir(saveLRpath): os.mkdir(saveLRpath) 
else: print("It will cover " + str(saveLRpath)) if not os.path.isdir(saveBicpath): os.mkdir(saveBicpath) else: print("It will cover " + str(saveBicpath)) if not os.path.isdir(saveLRblurpath): os.mkdir(saveLRblurpath) else: print("It will cover " + str(saveLRblurpath)) filepaths = sorted([f for f in os.listdir(sourcedir) if f.endswith(".png")]) print(filepaths) num_files = len(filepaths) # kernel_map_tensor = torch.zeros((num_files, 1, 10)) # each kernel map: 1*10 # prepare data with augementation for i in range(num_files): filename = filepaths[i] print("No.{} -- Processing {}".format(i, filename)) # read image image = cv2.imread(os.path.join(sourcedir, filename)) width = int(np.floor(image.shape[1] / mod_scale)) height = int(np.floor(image.shape[0] / mod_scale)) # modcrop if len(image.shape) == 3: image_HR = image[0 : mod_scale * height, 0 : mod_scale * width, :] else: image_HR = image[0 : mod_scale * height, 0 : mod_scale * width] # LR_blur, by random gaussian kernel img_HR = util.img2tensor(image_HR) C, H, W = img_HR.size() for sig in np.linspace(1.8, 3.2, 8): prepro = util.SRMDPreprocessing(sig=sig, **degradation_setting) LR_img, ker_map = prepro(img_HR.view(1, C, H, W)) image_LR_blur = util.tensor2img(LR_img) cv2.imwrite( os.path.join(saveLRblurpath, "sig{}_{}".format(sig, filename)), image_LR_blur, ) cv2.imwrite( os.path.join(saveHRpath, "sig{}_{}".format(sig, filename)), image_HR ) # LR image_LR = imresize(image_HR, 1 / up_scale, True) # bic image_Bic = imresize(image_LR, up_scale, True) # cv2.imwrite(os.path.join(saveHRpath, filename), image_HR) cv2.imwrite(os.path.join(saveLRpath, filename), image_LR) cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic) # kernel_map_tensor[i] = ker_map # save dataset corresponding kernel maps # torch.save(kernel_map_tensor, './Set5_sig2.6_kermap.pth') print("Image Blurring & Down smaple Done: X" + str(up_scale)) if __name__ == "__main__": generate_mod_LR_bic() ================================================ FILE: 
def bgr2ycbcr(img, only_y=True):
    """Convert a BGR image to YCbCr using matlab-compatible coefficients.

    only_y: if True, return only the Y (luma) channel.
    Input:
        uint8 image in [0, 255] or float image in [0, 1].
    Output:
        same dtype as the input; H x W when only_y, else H x W x 3.
    """
    in_img_type = img.dtype
    # Bugfix: `astype` returns a NEW array; the original discarded the
    # result, so the float32 cast never happened and `img *= 255.0`
    # scaled the caller's float buffer in place. Rebinding fixes both.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = (
            np.matmul(
                img,
                [
                    [24.966, 112.0, -18.214],
                    [128.553, -74.203, -93.786],
                    [65.481, -37.797, 112.0],
                ],
            )
            / 255.0
            + [16, 128, 128]
        )
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)
osp.join(args.save_dir, f"{score_file_name}.txt") score_file = open(score_file_name, "w") measure = IQA(metrics=args.metrics, cuda=False) test_results_rgb = defaultdict(list) test_results_y = defaultdict(list) for indx, res_path in enumerate(res_paths): res_img = cv2.imread(res_path) message = f"image {res_path}\t" if args.ref_dir is not None: ref_img = cv2.imread(ref_paths[indx]) else: ref_img = None message += "Original Scores\t" scores = measure(res=res_img, ref=ref_img, metrics=args.metrics) for k, v in scores.items(): test_results_rgb[k].append(v) message += "{}: {:.6f}; ".format(k, v) if res_img.ndim == 3: res_img_y = bgr2ycbcr(res_img, only_y=True) if ref_img is not None: ref_img_y = bgr2ycbcr(ref_img, only_y=True) else: ref_img_y = None message += "Y Scores\t" scores = measure(res=res_img_y, ref=ref_img_y, metrics=args.metrics) for k, v in scores.items(): test_results_y[k].append(v) message += "{}: {:.6f}; ".format(k, v) print(message) score_file.write(message + "\n") message = "-" * 10 + "Average Results" + "-" * 10 + "\n" message += "Origianl Scores\t" for k, v in test_results_rgb.items(): ave = sum(v) / len(v) message += "{}: {:.6f}; ".format(k, ave) if len(test_results_y) > 0: message += "Y Scores\t" for k, v in test_results_y.items(): ave = sum(v) / len(v) message += "{}: {:.6f}; ".format(k, ave) print(message) score_file.write(message) score_file.close() if __name__ == "__main__": main() ================================================ FILE: codes/utils/__init__.py ================================================ import importlib import os import os.path as osp utils_folder = osp.dirname(__file__) utils_names = [ osp.splitext(osp.basename(v))[0] for v in os.listdir(utils_folder) if v.endswith("_utils.py") ] for file_name in utils_names: exec(f"from .{file_name} import *") ================================================ FILE: codes/utils/data_utils.py ================================================ import math import os import pickle import random 
import cv2 import numpy as np import torch # Files & IO IMG_EXTENSIONS = [ ".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP", ] def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) def _get_paths_from_images(path): """get image path list from image folder""" assert os.path.isdir(path), "{:s} is not a valid directory".format(path) images = [] for dirpath, _, fnames in sorted(os.walk(path)): for fname in sorted(fnames): if is_image_file(fname): img_path = os.path.join(dirpath, fname) images.append(img_path) assert images, "{:s} has no valid image file".format(path) return images def _get_paths_from_lmdb(dataroot): """get image path list from lmdb meta info""" meta_info = pickle.load(open(os.path.join(dataroot, "meta_info.pkl"), "rb")) paths = meta_info["keys"] sizes = meta_info["resolution"] if len(sizes) == 1: sizes = sizes * len(paths) return paths, sizes def get_image_paths(data_type, dataroot): """get image path list support lmdb or image files""" paths, sizes = None, None if dataroot is None: return None, None else: if data_type == "lmdb": paths, sizes = _get_paths_from_lmdb(dataroot) return paths, sizes elif data_type == "img": paths = sorted(_get_paths_from_images(dataroot)) return paths, None else: raise NotImplementedError( "data_type [{:s}] is not recognized.".format(data_type) ) def _read_img_lmdb(env, key, size): """read image from lmdb with key (w/ and w/o fixed size) size: (C, H, W) tuple""" with env.begin(write=False) as txn: buf = txn.get(key.encode("ascii")) img_flat = np.frombuffer(buf, dtype=np.uint8) C, H, W = size img = img_flat.reshape(H, W, C) return img def read_img(env, path, size=None): """read image by cv2 or from lmdb return: Numpy float32, HWC, BGR, [0,1]""" if env is None: # img img = cv2.imread(path, cv2.IMREAD_UNCHANGED) else: img = _read_img_lmdb(env, path, size) img = img.astype(np.float32) / 255.0 if img.ndim == 2: img = np.expand_dims(img, axis=2) # 
def augment(img, hflip=True, rot=True, mode=None):
    """Randomly flip/rotate a list of HWC images with one shared draw.

    img: list of numpy arrays (H, W, C). A single-element list is
        unwrapped and returned as a bare array.
    hflip: enable the horizontal-flip coin toss.
    rot: enable the vertical-flip and 90-degree-rotation coin tosses.
    mode: unused; kept for interface compatibility.
    """
    # Draw the coin flips once so every image gets the same augmentation.
    # Short-circuiting matters: disabled ops consume no random numbers.
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rot and random.random() < 0.5
    do_rot90 = rot and random.random() < 0.5

    def _apply(im):
        if do_hflip:
            im = im[:, ::-1, :]
        if do_vflip:
            im = im[::-1, :, :]
        if do_rot90:
            im = im.transpose(1, 0, 2)
        return im

    if len(img) == 1:
        return _apply(img[0])
    return [_apply(each) for each in img]


def augment_flow(img_list, flow_list, hflip=True, rot=True):
    """Jointly flip/rotate images and optical-flow maps.

    Flow vectors are negated/swapped so they remain consistent with the
    geometric transform applied to the images.
    """
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rot and random.random() < 0.5
    do_rot90 = rot and random.random() < 0.5

    def _apply_img(im):
        if do_hflip:
            im = im[:, ::-1, :]
        if do_vflip:
            im = im[::-1, :, :]
        if do_rot90:
            im = im.transpose(1, 0, 2)
        return im

    def _apply_flow(fl):
        if do_hflip:
            fl = fl[:, ::-1, :]
            fl[:, :, 0] *= -1  # x-displacement changes sign
        if do_vflip:
            fl = fl[::-1, :, :]
            fl[:, :, 1] *= -1  # y-displacement changes sign
        if do_rot90:
            fl = fl.transpose(1, 0, 2)
            fl = fl[:, :, [1, 0]]  # swap (dx, dy) after transpose
        return fl

    return [_apply_img(im) for im in img_list], [
        _apply_flow(fl) for fl in flow_list
    ]
""" assert scale in [2, 3, 4], "Scale [{}] is not supported".format(scale) def gkern(kernlen=13, nsig=1.6): import scipy.ndimage.filters as fi inp = np.zeros((kernlen, kernlen)) # set element at the middle to one, a dirac delta inp[kernlen // 2, kernlen // 2] = 1 # gaussian-smooth the dirac, resulting in a gaussian filter mask return fi.gaussian_filter(inp, nsig) B, T, C, H, W = x.size() x = x.view(-1, 1, H, W) pad_w, pad_h = 6 + scale * 2, 6 + scale * 2 # 6 is the pad of the gaussian filter r_h, r_w = 0, 0 if scale == 3: r_h = 3 - (H % 3) r_w = 3 - (W % 3) x = F.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], "reflect") gaussian_filter = ( torch.from_numpy(gkern(13, 0.4 * scale)).type_as(x).unsqueeze(0).unsqueeze(0) ) x = F.conv2d(x, gaussian_filter, stride=scale) x = x[:, :, 2:-2, 2:-2] x = x.view(B, T, C, x.size(2), x.size(3)) return x def PCA(data, k=2): X = torch.from_numpy(data) X_mean = torch.mean(X, 0) X = X - X_mean.expand_as(X) U, S, V = torch.svd(torch.t(X)) return U[:, :k] # PCA matrix def random_batch_kernel( batch, l=21, sig_min=0.2, sig_max=4.0, rate_iso=1.0, tensor=True, random_disturb=False, ): if rate_iso == 1: sigma = np.random.uniform(sig_min, sig_max, (batch, 1, 1)) ax = np.arange(-l // 2 + 1.0, l // 2 + 1.0) xx, yy = np.meshgrid(ax, ax) xx = xx[None].repeat(batch, 0) yy = yy[None].repeat(batch, 0) kernel = np.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma ** 2)) kernel = kernel / np.sum(kernel, (1, 2), keepdims=True) return torch.FloatTensor(kernel) if tensor else kernel else: sigma_x = np.random.uniform(sig_min, sig_max, (batch, 1, 1)) sigma_y = np.random.uniform(sig_min, sig_max, (batch, 1, 1)) D = np.zeros((batch, 2, 2)) D[:, 0, 0] = sigma_x.squeeze() ** 2 D[:, 1, 1] = sigma_y.squeeze() ** 2 radians = np.random.uniform(-np.pi, np.pi, (batch)) mask_iso = np.random.uniform(0, 1, (batch)) < rate_iso radians[mask_iso] = 0 sigma_y[mask_iso] = sigma_x[mask_iso] U = np.zeros((batch, 2, 2)) U[:, 0, 0] = np.cos(radians) U[:, 0, 1] = -np.sin(radians) U[:, 
def stable_batch_kernel(batch, l=21, sig=2.6, tensor=True):
    """Build `batch` identical isotropic Gaussian blur kernels (l x l).

    Each kernel is normalized to sum to 1. Returns a FloatTensor when
    `tensor` is True, otherwise a numpy array of shape (batch, l, l).
    """
    ax = np.arange(-l // 2 + 1.0, l // 2 + 1.0)
    xx, yy = np.meshgrid(ax, ax)
    xx = np.repeat(xx[None], batch, axis=0)
    yy = np.repeat(yy[None], batch, axis=0)
    kernel = np.exp(-(xx ** 2 + yy ** 2) / (2.0 * sig ** 2))
    kernel /= np.sum(kernel, (1, 2), keepdims=True)  # normalize to sum 1
    return torch.FloatTensor(kernel) if tensor else kernel


def b_Bicubic(variable, scale):
    """Bicubic-downsample a (B, C, H, W) tensor by `scale` via imresize."""
    B, C, H, W = variable.size()
    reshaped = variable.view((B, C, H, W))
    return imresize(reshaped, 1 / scale)


def random_batch_noise(batch, high, rate_cln=1.0):
    """Sample per-image noise levels in [0, high).

    A `rate_cln` fraction of the batch (in expectation) is forced to
    zero ("clean"); the rest keep their sampled level.
    """
    noise_level = np.random.uniform(size=(batch, 1)) * high
    mask = np.random.uniform(size=(batch, 1))
    # below the clean-rate threshold -> clean (0), otherwise noisy (1)
    binary = np.where(mask < rate_cln, 0.0, 1.0)
    return noise_level * binary
def b_GaussianNoising(tensor, noise_high, mean=0.0, noise_size=None, min=0.0, max=1.0):
    """Add i.i.d. Gaussian noise to `tensor` and clamp to [min, max].

    NOTE(review): the module defines `b_GaussianNoising` twice; this
    second definition (scalar `noise_high` used as the normal's std-dev)
    shadows the earlier per-sample-sigma variant, which is dead code.

    tensor: torch tensor the noise is added to.
    noise_high: standard deviation of the normal distribution (loc=mean).
    noise_size: optional explicit shape for the noise; defaults to
        tensor.size().
    min / max: clamp bounds (names shadow builtins; kept for interface
        compatibility).
    """
    size = tensor.size() if noise_size is None else noise_size
    noise = torch.FloatTensor(
        np.random.normal(loc=mean, scale=noise_high, size=size)
    ).to(tensor.device)
    return torch.clamp(noise + tensor, min=min, max=max)
class SRMDPreprocessing(object):
    """Synthesize LR batches from HR batches: blur -> bicubic down -> noise.

    Also returns the PCA-encoded kernel code used by SRMD-style models.
    Callable: (lr, code) or (lr, code, kernels) = self(hr, kernel=True).
    """

    def __init__(
        self,
        scale,
        pca_matrix,
        ksize=21,
        code_length=10,
        random_kernel=True,
        noise=False,
        cuda=False,
        random_disturb=False,
        sig=0,
        sig_min=0,
        sig_max=0,
        rate_iso=1.0,
        rate_cln=1,
        noise_high=0,
        stored_kernel=False,
        pre_kernel_path=None,
    ):
        # Bugfix: the CPU branch referenced an undefined name `pca`
        # (NameError whenever cuda=False); use `pca_matrix` in both branches.
        encoder = PCAEncoder(pca_matrix)
        self.encoder = encoder.cuda() if cuda else encoder
        # Either synthesize Gaussian kernels on the fly, or draw from a
        # pre-stored kernel bank loaded from `pre_kernel_path`.
        self.kernel_gen = (
            BatchSRKernel(
                l=ksize,
                sig=sig,
                sig_min=sig_min,
                sig_max=sig_max,
                rate_iso=rate_iso,
                random_disturb=random_disturb,
            )
            if not stored_kernel
            else BatchBlurKernel(pre_kernel_path)
        )
        self.blur = BatchBlur(l=ksize)
        self.para_in = code_length
        self.l = ksize
        self.noise = noise
        self.scale = scale
        self.cuda = cuda
        self.rate_cln = rate_cln
        self.noise_high = noise_high
        self.random = random_kernel

    def __call__(self, hr_tensor, kernel=False):
        # hr_tensor: (B, C, H, W) CPU tensor (copied onto GPU if cuda).
        hr_var = (
            torch.FloatTensor(hr_tensor).cuda()
            if self.cuda
            else torch.FloatTensor(hr_tensor)
        )
        device = hr_var.device
        B, C, H, W = hr_var.size()
        b_kernels = torch.FloatTensor(
            self.kernel_gen(self.random, B, tensor=True)
        ).to(device)
        hr_blured_var = self.blur(hr_var, b_kernels)
        # B x self.para_in kernel code
        kernel_code = self.encoder(b_kernels)
        # Down sample
        if self.scale != 1:
            lr_blured_t = b_Bicubic(hr_blured_var, self.scale)
        else:
            lr_blured_t = hr_blured_var
        # Noisy
        if self.noise:
            Noise_level = torch.FloatTensor(
                random_batch_noise(B, self.noise_high, self.rate_cln)
            )
            lr_noised_t = b_GaussianNoising(lr_blured_t, self.noise_high)
        else:
            Noise_level = torch.zeros((B, 1))
            lr_noised_t = lr_blured_t
        # Bugfix: the noise-level code was moved to GPU unconditionally
        # (`.cuda()`), crashing CPU-only runs; follow the working device.
        Noise_level = Noise_level.to(device)
        re_code = (
            torch.cat([kernel_code, Noise_level * 10], dim=1)
            if self.noise
            else kernel_code
        )
        lr_re = torch.FloatTensor(lr_noised_t).to(device)
        return (lr_re, re_code, b_kernels) if kernel else (lr_re, re_code)
max_bar_width = self._get_max_bar_width() self.bar_width = bar_width if bar_width <= max_bar_width else max_bar_width self.completed = 0 if start: self.start() def _get_max_bar_width(self): terminal_width, _ = get_terminal_size() max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) if max_bar_width < 10: print( "terminal width is too small ({}), please consider widen the terminal for better " "progressbar visualization".format(terminal_width) ) max_bar_width = 10 return max_bar_width def start(self): if self.task_num > 0: sys.stdout.write( "[{}] 0/{}, elapsed: 0s, ETA:\n{}\n".format( " " * self.bar_width, self.task_num, "Start..." ) ) else: sys.stdout.write("completed: 0, elapsed: 0s") sys.stdout.flush() self.start_time = time.time() def update(self, msg="In progress..."): self.completed += 1 elapsed = time.time() - self.start_time fps = self.completed / elapsed if self.task_num > 0: percentage = self.completed / float(self.task_num) eta = int(elapsed * (1 - percentage) / percentage + 0.5) mark_width = int(self.bar_width * percentage) bar_chars = ">" * mark_width + "-" * (self.bar_width - mark_width) sys.stdout.write("\033[2F") # cursor up 2 lines sys.stdout.write( "\033[J" ) # clean the output (remove extra chars since last display) sys.stdout.write( "[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n".format( bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg, ) ) else: sys.stdout.write( "completed: {}, elapsed: {}s, {:.1f} tasks/s".format( self.completed, int(elapsed + 0.5), fps ) ) sys.stdout.flush() ================================================ FILE: codes/utils/img_utils.py ================================================ import math import os import cv2 import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torchvision.utils import make_grid def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): """ Converts a torch Tensor into an image Numpy array Input: 
def save_img(img, img_path, mode="BGR"):
    """Write a numpy image to disk (cv2 expects BGR channel order).

    mode is unused; kept for interface compatibility.
    """
    cv2.imwrite(img_path, img)


def img2tensor(img):
    """Convert an H x W x C uint8 BGR image in [0, 255] to a C x H x W
    float tensor in [0, 1] with RGB channel order."""
    scaled = img.astype(np.float32) / 255.0
    rgb = scaled[:, :, [2, 1, 0]]  # BGR -> RGB
    chw = np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)))  # HWC -> CHW
    return torch.from_numpy(chw).float()
def rgb2ycbcr(img, only_y=True):
    """Same as matlab rgb2ycbcr.

    only_y: only return the Y channel.
    Input: uint8 in [0, 255] or float in [0, 1].
    Output: same dtype as the input.
    """
    in_img_type = img.dtype
    # Bugfix: `astype` returns a NEW array; the original discarded it,
    # skipping the cast and mutating float callers' buffers in place.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = (
            np.matmul(
                img,
                [
                    [65.481, -37.797, 112.0],
                    [128.553, -74.203, -93.786],
                    [24.966, 112.0, -18.214],
                ],
            )
            / 255.0
            + [16, 128, 128]
        )
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)


def bgr2ycbcr(img, only_y=True):
    """bgr version of rgb2ycbcr.

    only_y: only return the Y channel.
    Input: uint8 in [0, 255] or float in [0, 1].
    Output: same dtype as the input.
    """
    in_img_type = img.dtype
    # Bugfix: same discarded-`astype` defect as rgb2ycbcr (see above).
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = (
            np.matmul(
                img,
                [
                    [24.966, 112.0, -18.214],
                    [128.553, -74.203, -93.786],
                    [65.481, -37.797, 112.0],
                ],
            )
            / 255.0
            + [16, 128, 128]
        )
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)


def ycbcr2rgb(img):
    """Same as matlab ycbcr2rgb.

    Input: uint8 in [0, 255] or float in [0, 1].
    Output: same dtype as the input.
    """
    in_img_type = img.dtype
    # Bugfix: same discarded-`astype` defect as rgb2ycbcr (see above).
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.0
    # convert
    rlt = (
        np.matmul(
            img,
            [
                [0.00456621, 0.00456621, 0.00456621],
                [0, -0.00153632, 0.00791071],
                [0.00625893, -0.00318811, 0],
            ],
        )
        * 255.0
        + [-222.921, 135.576, -276.836]
    )
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.0
    return rlt.astype(in_img_type)


def modcrop(img_in, scale):
    """Crop an HWC or HW numpy image so both H and W are multiples of
    `scale` (pixels are trimmed from the bottom/right edges)."""
    img = np.copy(img_in)
    if img.ndim == 2:
        H, W = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[: H - H_r, : W - W_r]
    elif img.ndim == 3:
        H, W, C = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[: H - H_r, : W - W_r, :]
    else:
        raise ValueError("Wrong img ndim: [{:d}].".format(img.ndim))
    return img
codes/utils/option.py
================================================
import logging
import os
import os.path as osp
import sys
from collections import OrderedDict

import yaml


def ordered_yaml():
    """Support OrderedDict for yaml.

    Returns:
        yaml Loader and Dumper.
    """
    try:
        # Prefer the C-accelerated loader/dumper when libyaml is available.
        from yaml import CDumper as Dumper
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def dict_representer(dumper, data):
        # Dump OrderedDict as a plain yaml mapping.
        return dumper.represent_dict(data.items())

    def dict_constructor(loader, node):
        # Load every yaml mapping as an OrderedDict (preserves key order).
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, dict_representer)
    Loader.add_constructor(_mapping_tag, dict_constructor)
    return Loader, Dumper


def parse(opt_path, root_path=".", is_train=True):
    """Load a yaml option file and fill in derived/runtime settings.

    Side effects: sets CUDA_VISIBLE_DEVICES from opt["gpu_ids"] and prints it.
    NOTE(review): `osp.abspath(opt_path).split("/")` assumes POSIX path
    separators and that a "config" directory appears in the path — verify on
    Windows or nonstandard layouts.
    """
    opt_path = osp.abspath(opt_path)
    with open(opt_path, mode="r") as f:
        Loader, _ = ordered_yaml()
        opt = yaml.load(f, Loader=Loader)

    # export CUDA_VISIBLE_DEVICES
    gpu_list = ",".join(str(x) for x in opt["gpu_ids"])
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list
    print("export CUDA_VISIBLE_DEVICES=" + gpu_list)

    opt["is_train"] = is_train

    # datasets: tag each dataset dict with its phase and the global scale
    for phase, dataset in opt["datasets"].items():
        for p in ["train", "val", "test"]:
            if p in phase:
                dataset["phase"] = phase
        dataset["scale"] = opt.get("scale", 1)

    # path
    if not opt.get("path"):
        opt["path"] = {}
    opt["path"]["root"] = osp.abspath(root_path)
    config_paths = osp.abspath(opt_path).split("/")
    # the directory directly under "config" names this experiment family
    config_dir = config_paths[config_paths.index("config") + 1]
    if is_train:
        experiments_root = osp.join(
            opt["path"]["root"], "experiments", config_dir, opt["name"]
        )
        opt["path"]["experiments_root"] = experiments_root
        for dirname in ["models", "training_state", "log", "val_images"]:
            opt["path"][dirname] = osp.join(experiments_root, dirname)

        # change some options for debug mode
        if "debug" in opt["name"]:
            opt["train"]["val_freq"] = 8
            opt["logger"]["print_freq"] = 1
            opt["logger"]["save_checkpoint_freq"] = 8
    else:  # test
        results_root = 
osp.join(opt["path"]["root"], "results", config_dir, opt["name"])
        opt["path"]["results_root"] = results_root
        opt["path"]["log"] = osp.join(results_root, "log")

    return opt


def dict2str(opt, indent_l=1):
    """dict to string for logger"""
    msg = ""
    for k, v in opt.items():
        if isinstance(v, dict):
            # Recurse with one extra indent level for nested dicts.
            msg += " " * (indent_l * 2) + k + ":[\n"
            msg += dict2str(v, indent_l + 1)
            msg += " " * (indent_l * 2) + "]\n"
        else:
            msg += " " * (indent_l * 2) + k + ": " + str(v) + "\n"
    return msg


class NoneDict(dict):
    # dict that yields None (instead of raising KeyError) for missing keys.
    def __missing__(self, key):
        return None


# convert to NoneDict, which return None for missing key.
def dict_to_nonedict(opt):
    # Recursively convert nested dicts (and dicts inside lists) to NoneDict.
    if isinstance(opt, dict):
        new_opt = dict()
        for key, sub_opt in opt.items():
            new_opt[key] = dict_to_nonedict(sub_opt)
        return NoneDict(**new_opt)
    elif isinstance(opt, list):
        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
    else:
        return opt

================================================ FILE: codes/utils/registry.py ================================================
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py  # noqa: E501


class Registry:
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj):
        # Duplicate registration is a programming error: fail loudly.
        assert name not in self._obj_map, (
            f"An object named '{name}' was already registered "
            f"in '{self._name}' registry!"
        )
        self._obj_map[name] = obj

    def register(self, obj=None):
        """
        Register the given object under the the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = func_or_class.__name__
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj)

    def get(self, name):
        # Look up a registered object; KeyError with a helpful message if absent.
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError(
                f"No object named '{name}' found in '{self._name}' registry!"
            )
        return ret

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        # Iterates (name, obj) pairs, like dict.items().
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()


# Global registries used across the project for plugin-style lookup by name.
DATASET_REGISTRY = Registry("dataset")
ARCH_REGISTRY = Registry("arch")
MODEL_REGISTRY = Registry("model")
LOSS_REGISTRY = Registry("loss")
METRIC_REGISTRY = Registry("metric")
LR_SCHEDULER_REGISTRY = Registry("lr_scheduler")

================================================ FILE: codes/utils/resize_utils.py ================================================
import math

import numpy as np
import torch

# matlab 'imresize' function, now only support 'bicubic'


def cubic(x):
    # Keys' bicubic kernel with a = -0.5 (same as MATLAB imresize's default).
    absx = torch.abs(x)
    absx2 = absx ** 2
    absx3 = absx ** 3
    weight = (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
        -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
    ) * (((absx > 1) * (absx <= 2)).type_as(absx))
    return weight


def calculate_weights_indices(
    in_length, out_length, scale, kernel, kernel_width, antialiasing
):
    # Compute, per output pixel, the input indices and bicubic weights used to
    # resample a 1-D signal of `in_length` samples to `out_length` samples.
    # NOTE(review): `kernel` is accepted but unused here — only 'cubic' is
    # supported (see module comment).
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias-
        # larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
left = torch.floor(u - kernel_width / 2) # What is the maximum number of pixels that can be involved in the # computation? Note: it's OK to use an extra pixel here; if the # corresponding weights are all zero, it will be eliminated at the end # of this function. P = math.ceil(kernel_width) + 2 # The indices of the input pixels involved in computing the k-th output # pixel are in row k of the indices matrix. indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace( 0, P - 1, P ).view(1, P).expand(out_length, P) # The weights used to compute the k-th output pixel are in row k of the # weights matrix. distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices # apply cubic kernel if (scale < 1) and (antialiasing): weights = scale * cubic(distance_to_center * scale) else: weights = cubic(distance_to_center) # Normalize the weights matrix so that each row sums to 1. weights_sum = torch.sum(weights, 1).view(out_length, 1) weights = weights / weights_sum.expand(out_length, P) # If a column in weights is all zero, get rid of it. only consider the first and last column. 
    # weights_zero_tmp[j] counts how many rows have a zero weight in column j;
    # a nonzero count in the first/last column means that column is droppable.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # Amount of symmetric padding needed at the start/end so every index is
    # in-bounds after shifting.
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)


def imresize(img, scale, antialiasing=True):
    # MATLAB-compatible bicubic resize.
    # Now the scale should be the same for H and W
    # input: img: CHW RGB [0,1]
    # output: CHW RGB [0,1] w/o round
    # Accepts a numpy HWC array or a torch CHW / BCHW tensor; returns the
    # same container kind it was given (numpy in -> numpy HWC out).
    is_numpy = False
    if isinstance(img, np.ndarray):
        img = torch.from_numpy(img.transpose(2, 0, 1))  # HWC -> CHW
        is_numpy = True
    device = img.device

    is_batch = True
    if len(img.shape) == 3:  # C, H, W
        img = img[None]  # add a batch dim; removed again before returning
        is_batch = False

    B, in_C, in_H, in_W = img.size()
    # Fold batch and channel together: each (H, W) plane is resized alike.
    img = img.view(-1, in_H, in_W)
    _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = "cubic"

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.
# get weights and indices weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( in_H, out_H, scale, kernel, kernel_width, antialiasing ) weights_H, indices_H = weights_H.to(device), indices_H.to(device) weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( in_W, out_W, scale, kernel, kernel_width, antialiasing ) weights_W, indices_W = weights_W.to(device), indices_W.to(device) # process H dimension # symmetric copying img_aug = torch.FloatTensor(B * in_C, in_H + sym_len_Hs + sym_len_He, in_W).to( device ) img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) sym_patch = img[:, :sym_len_Hs, :] inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long().to(device) sym_patch_inv = sym_patch.index_select(1, inv_idx) img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) sym_patch = img[:, -sym_len_He:, :] inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long().to(device) sym_patch_inv = sym_patch.index_select(1, inv_idx) img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) out_1 = torch.FloatTensor(B * in_C, out_H, in_W).to(device) kernel_width = weights_H.size(1) for i in range(out_H): idx = int(indices_H[i][0]) out_1[:, i, :] = ( img_aug[:, idx : idx + kernel_width, :] .transpose(1, 2) .matmul(weights_H[i][None, :, None].repeat(B * in_C, 1, 1)) ).squeeze() # process W dimension # symmetric copying out_1_aug = torch.FloatTensor(B * in_C, out_H, in_W + sym_len_Ws + sym_len_We).to( device ) out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) sym_patch = out_1[:, :, :sym_len_Ws] inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long().to(device) sym_patch_inv = sym_patch.index_select(2, inv_idx) out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) sym_patch = out_1[:, :, -sym_len_We:] inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long().to(device) sym_patch_inv = sym_patch.index_select(2, inv_idx) out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) out_2 = torch.FloatTensor(B * in_C, 
out_H, out_W).to(device)
    kernel_width = weights_W.size(1)
    # Each output column i is a weighted sum of `kernel_width` padded columns.
    for i in range(out_W):
        idx = int(indices_W[i][0])
        out_2[:, :, i] = (
            out_1_aug[:, :, idx : idx + kernel_width].matmul(
                weights_W[i][None, :, None].repeat(B * in_C, 1, 1)
            )
        ).squeeze()

    # Restore the (B, C, H, W) layout; drop the batch dim if we added it.
    out_2 = out_2.contiguous().view(B, in_C, out_H, out_W)
    if not is_batch:
        out_2 = out_2[0]

    # numpy input came in as HWC, so hand it back as HWC numpy.
    return out_2.cpu().numpy().transpose(1, 2, 0) if is_numpy else out_2