Repository: The-Learning-And-Vision-Atelier-LAVA/DASR Branch: main Commit: 51524af8021c Files: 33 Total size: 91.3 MB Directory structure: gitextract_01sw5xtg/ ├── LICENSE ├── README.md ├── data/ │ ├── __init__.py │ ├── benchmark.py │ ├── common.py │ ├── df2k.py │ └── multiscalesrdata.py ├── dataloader.py ├── experiment/ │ ├── blindsr_x2_bicubic_iso/ │ │ └── model/ │ │ └── model_600.pt │ ├── blindsr_x3_bicubic_iso/ │ │ └── model/ │ │ └── model_600.pt │ ├── blindsr_x4_bicubic_aniso/ │ │ └── model/ │ │ └── model_600.pt │ └── blindsr_x4_bicubic_iso/ │ └── model/ │ └── model_600.pt ├── loss/ │ ├── __init__.py │ ├── adversarial.py │ ├── discriminator.py │ └── vgg.py ├── main.py ├── main.sh ├── moco/ │ ├── __init__.py │ └── builder.py ├── model/ │ ├── __init__.py │ ├── blindsr.py │ └── common.py ├── option.py ├── quick_test.py ├── quick_test.sh ├── template.py ├── test.py ├── test.sh ├── trainer.py ├── utility.py └── utils/ ├── __init__.py └── util.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2022 The Learning and Vision Atelier (LAVA) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # DASR Pytorch implementation of "Unsupervised Degradation Representation Learning for Blind Super-Resolution", CVPR 2021 [[arXiv]](http://arxiv.org/pdf/2104.00416) [[CVF]](https://openaccess.thecvf.com/content/CVPR2021/papers/Wang_Unsupervised_Degradation_Representation_Learning_for_Blind_Super-Resolution_CVPR_2021_paper.pdf) [[Supp]](https://openaccess.thecvf.com/content/CVPR2021/supplemental/Wang_Unsupervised_Degradation_Representation_CVPR_2021_supplemental.pdf) ## Overview

## Requirements - Python 3.6 - PyTorch == 1.1.0 - numpy - skimage - imageio - matplotlib - cv2 ## Train ### 1. Prepare training data 1.1 Download the [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) dataset and the [Flickr2K](http://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) dataset. 1.2 Combine the HR images from these two datasets in `your_data_path/DF2K/HR` to build the DF2K dataset. ### 2. Begin to train Run `./main.sh` to train on the DF2K dataset. Please update `dir_data` in the bash file as `your_data_path`. ## Test ### 1. Prepare test data Download [benchmark datasets](https://github.com/xinntao/BasicSR/blob/a19aac61b277f64be050cef7fe578a121d944a0e/docs/Datasets.md) (e.g., Set5, Set14 and other test sets) and prepare HR/LR images in `your_data_path/benchmark`. ### 2. Begin to test Run `./test.sh` to test on benchmark datasets. Please update `dir_data` in the bash file as `your_data_path`. ## Quick Test on An LR Image Run `./quick_test.sh` to test on an LR image. Please update `img_dir` in the bash file as `your_img_path`. ## Visualization of Degradation Representations

## Comparative Results ### Noise-Free Degradations with Isotropic Gaussian Kernels

### General Degradations with Anisotropic Gaussian Kernels and Noises

### Unseen Degradations

### Real Degradations (AIM real-world SR challenge)

## Citation ``` @InProceedings{Wang2021Unsupervised, author = {Wang, Longguang and Wang, Yingqian and Dong, Xiaoyu and Xu, Qingyu and Yang, Jungang and An, Wei and Guo, Yulan}, title = {Unsupervised Degradation Representation Learning for Blind Super-Resolution}, booktitle = {CVPR}, year = {2021}, } ``` ## Acknowledgements This code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch), [IKC](https://github.com/yuanjunchai/IKC) and [MoCo](https://github.com/facebookresearch/moco). We thank the authors for sharing the codes. ================================================ FILE: data/__init__.py ================================================ from importlib import import_module from dataloader import MSDataLoader class Data: def __init__(self, args): self.loader_train = None if not args.test_only: module_train = import_module('data.' + args.data_train.lower()) ## load the right dataset loader module trainset = getattr(module_train, args.data_train)(args) ## load the dataset, args.data_train is the dataset name self.loader_train = MSDataLoader( args, trainset, batch_size=args.batch_size, shuffle=True, pin_memory=not args.cpu ) if args.data_test in ['Set5', 'Set14', 'B100', 'Manga109', 'Urban100']: module_test = import_module('data.benchmark') testset = getattr(module_test, 'Benchmark')(args, name=args.data_test,train=False) else: module_test = import_module('data.' + args.data_test.lower()) testset = getattr(module_test, args.data_test)(args, train=False) self.loader_test = MSDataLoader( args, testset, batch_size=1, shuffle=False, pin_memory=not args.cpu ) ================================================ FILE: data/benchmark.py ================================================ import os from data import common from data import multiscalesrdata as srdata class Benchmark(srdata.SRData): def __init__(self, args, name='', train=True): super(Benchmark, self).__init__( args, name=name, train=train, benchmark=True ) def _set_filesystem(self, dir_data): self.apath = os.path.join(dir_data,'benchmark', self.name) self.dir_hr = os.path.join(self.apath, 'HR') self.dir_lr = os.path.join(self.apath, 'LR_bicubic') self.ext = ('.png','.png') print(self.dir_hr) print(self.dir_lr) ================================================ FILE: data/common.py ================================================ import random import numpy as np import skimage.color as sc import torch def get_patch(img, patch_size=48, scale=1): th, tw = img.shape[:2] ## HR image tp = round(scale * patch_size) tx = random.randrange(0, (tw-tp)) ty = random.randrange(0, (th-tp)) return img[ty:ty + tp, tx:tx + tp, :] def set_channel(img, n_channels=3): if img.ndim == 2: img = np.expand_dims(img, axis=2) c = img.shape[2] if n_channels == 1 and c == 3: img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) elif n_channels == 3 and c == 1: img = np.concatenate([img] * n_channels, 2) return img def np2Tensor(img, rgb_range=255): np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) tensor = torch.from_numpy(np_transpose).float() tensor.mul_(rgb_range / 255) return tensor def augment(img, hflip=True, rot=True): hflip = hflip and random.random() < 0.5 vflip = rot and random.random() < 0.5 rot90 = rot and random.random() < 0.5 if hflip: img = img[:, ::-1, :] if vflip: img = img[::-1, :, :] if rot90: img = img.transpose(1, 0, 2) return img ================================================ FILE: data/df2k.py ================================================ import os from data import multiscalesrdata class DF2K(multiscalesrdata.SRData): def __init__(self, args, name='DF2K', train=True, benchmark=False): super(DF2K, self).__init__(args, name=name, train=train, benchmark=benchmark) def _scan(self): names_hr = super(DF2K, self)._scan() names_hr = names_hr[self.begin - 1:self.end] return names_hr def _set_filesystem(self, dir_data): super(DF2K, self)._set_filesystem(dir_data) self.dir_hr = os.path.join(self.apath, 'HR') self.dir_lr = os.path.join(self.apath, 'LR_bicubic') ================================================ FILE: data/multiscalesrdata.py ================================================ import os import glob from data import common import pickle import numpy as np import imageio import torch import torch.utils.data as data class SRData(data.Dataset): def __init__(self, args, name='', train=True, benchmark=False): self.args = args self.name = name self.train = train self.split = 'train' if train else 'test' self.do_eval = True self.benchmark = benchmark self.scale = args.scale self.idx_scale = 0 data_range = [r.split('-') for r in args.data_range.split('/')] if train: data_range = data_range[0] else: if args.test_only and len(data_range) == 1: data_range = data_range[0] else: data_range = data_range[1] self.begin, self.end = list(map(lambda x: int(x), data_range)) self._set_filesystem(args.dir_data) if args.ext.find('img') < 0: path_bin = os.path.join(self.apath, 'bin') os.makedirs(path_bin, exist_ok=True) list_hr = self._scan() if args.ext.find('bin') >= 0: # Binary files are stored in 'bin' folder # If the binary file exists, load it. If not, make it. list_hr = self._scan() self.images_hr = self._check_and_load( args.ext, list_hr, self._name_hrbin() ) else: if args.ext.find('img') >= 0 or benchmark: self.images_hr = list_hr elif args.ext.find('sep') >= 0: os.makedirs( self.dir_hr.replace(self.apath, path_bin), exist_ok=True ) self.images_hr = [] for h in list_hr: b = h.replace(self.apath, path_bin) b = b.replace(self.ext[0], '.pt') self.images_hr.append(b) self._check_and_load( args.ext, [h], b, verbose=True, load=False ) if train: self.repeat = args.test_every // (len(self.images_hr) // args.batch_size) # Below functions as used to prepare images def _scan(self): names_hr = sorted( glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) ) print(len(names_hr)) return names_hr def _set_filesystem(self, dir_data): self.apath = os.path.join(dir_data, self.name) self.dir_hr = os.path.join(self.apath, 'HR') self.dir_lr = os.path.join(self.apath, 'LR_bicubic') self.ext = ('.png', '.png') def _name_hrbin(self): return os.path.join( self.apath, 'bin', '{}_bin_HR.pt'.format(self.split) ) def _name_lrbin(self, scale): return os.path.join( self.apath, 'bin', '{}_bin_LR_X{}.pt'.format(self.split, scale) ) def _check_and_load(self, ext, l, f, verbose=True, load=True): if os.path.isfile(f) and ext.find('reset') < 0: if load: if verbose: print('Loading {}...'.format(f)) with open(f, 'rb') as _f: ret = pickle.load(_f) return ret else: return None else: if verbose: if ext.find('reset') >= 0: print('Making a new binary: {}'.format(f)) else: print('{} does not exist. Now making binary...'.format(f)) b = [{ 'name': os.path.splitext(os.path.basename(_l))[0], 'image': imageio.imread(_l) } for _l in l] with open(f, 'wb') as _f: pickle.dump(b, _f) return b def __getitem__(self, idx): hr, filename = self._load_file(idx) hr = self.get_patch(hr) hr = [common.set_channel(img, n_channels=self.args.n_colors) for img in hr] hr_tensor = [common.np2Tensor(img, rgb_range=self.args.rgb_range) for img in hr] return torch.stack(hr_tensor, 0), filename def __len__(self): if self.train: return len(self.images_hr) * self.repeat else: return len(self.images_hr) def _get_index(self, idx): if self.train: return idx % len(self.images_hr) else: return idx def _load_file(self, idx): idx = self._get_index(idx) f_hr = self.images_hr[idx] if self.args.ext.find('bin') >= 0: filename = f_hr['name'] hr = f_hr['image'] else: filename, _ = os.path.splitext(os.path.basename(f_hr)) if self.args.ext == 'img' or self.benchmark: hr = imageio.imread(f_hr) elif self.args.ext.find('sep') >= 0: with open(f_hr, 'rb') as _f: hr = np.load(_f)[0]['image'] return hr, filename def get_patch(self, hr): scale = self.scale[self.idx_scale] if self.train: out = [] hr = common.augment(hr) if not self.args.no_augment else hr # extract two patches from each image for _ in range(2): hr_patch = common.get_patch( hr, patch_size=self.args.patch_size, scale=scale ) out.append(hr_patch) else: out = [hr] return out def set_scale(self, idx_scale): self.idx_scale = idx_scale ================================================ FILE: dataloader.py ================================================ import sys import threading import queue import random import collections import torch import torch.multiprocessing as multiprocessing from torch._C import _set_worker_signal_handlers from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import _DataLoaderIter from torch.utils.data import _utils if sys.version_info[0] == 2: import Queue as queue else: import queue def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id): global _use_shared_memory _use_shared_memory = True _set_worker_signal_handlers() torch.set_num_threads(1) torch.manual_seed(seed) while True: r = index_queue.get() if r is None: break idx, batch_indices = r try: idx_scale = 0 if len(scale) > 1 and dataset.train: idx_scale = random.randrange(0, len(scale)) dataset.set_scale(idx_scale) samples = collate_fn([dataset[i] for i in batch_indices]) samples.append(idx_scale) except Exception: data_queue.put((idx, _utils.ExceptionWrapper(sys.exc_info()))) else: data_queue.put((idx, samples)) class _MSDataLoaderIter(_DataLoaderIter): def __init__(self, loader): self.dataset = loader.dataset self.scale = loader.scale self.collate_fn = loader.collate_fn self.batch_sampler = loader.batch_sampler self.num_workers = loader.num_workers self.pin_memory = loader.pin_memory and torch.cuda.is_available() self.timeout = loader.timeout self.done_event = threading.Event() self.sample_iter = iter(self.batch_sampler) if self.num_workers > 0: self.worker_init_fn = loader.worker_init_fn self.index_queues = [ multiprocessing.Queue() for _ in range(self.num_workers) ] self.worker_queue_idx = 0 self.worker_result_queue = multiprocessing.Queue() self.batches_outstanding = 0 self.worker_pids_set = False self.shutdown = False self.send_idx = 0 self.rcvd_idx = 0 self.reorder_dict = {} base_seed = torch.LongTensor(1).random_()[0] self.workers = [ multiprocessing.Process( target=_ms_loop, args=( self.dataset, self.index_queues[i], self.worker_result_queue, self.collate_fn, self.scale, base_seed + i, self.worker_init_fn, i ) ) for i in range(self.num_workers)] if self.pin_memory or self.timeout > 0: self.data_queue = queue.Queue() if self.pin_memory: maybe_device_id = torch.cuda.current_device() else: # do not initialize cuda context if not necessary maybe_device_id = None self.pin_memory_thread = threading.Thread( target=_utils.pin_memory._pin_memory_loop, args=(self.worker_result_queue, self.data_queue, maybe_device_id, self.done_event)) self.pin_memory_thread.daemon = True self.pin_memory_thread.start() else: self.data_queue = self.worker_result_queue for w in self.workers: w.daemon = True # ensure that the worker exits on process exit w.start() _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers)) _utils.signal_handling._set_SIGCHLD_handler() self.worker_pids_set = True # prime the prefetch loop for _ in range(2 * self.num_workers): self._put_indices() class MSDataLoader(DataLoader): def __init__( self, args, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, collate_fn=_utils.collate.default_collate, pin_memory=False, drop_last=True, timeout=0, worker_init_fn=None): super(MSDataLoader, self).__init__( dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=args.n_threads, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn) self.scale = args.scale def __iter__(self): return _MSDataLoaderIter(self) ================================================ FILE: experiment/blindsr_x2_bicubic_iso/model/model_600.pt ================================================ [File too large to display: 22.3 MB] ================================================ FILE: experiment/blindsr_x3_bicubic_iso/model/model_600.pt ================================================ [File too large to display: 23.0 MB] ================================================ FILE: experiment/blindsr_x4_bicubic_aniso/model/model_600.pt ================================================ [File too large to display: 22.9 MB] ================================================ FILE: experiment/blindsr_x4_bicubic_iso/model/model_600.pt ================================================ [File too large to display: 22.9 MB] ================================================ FILE: loss/__init__.py ================================================ import os from importlib import import_module import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class Loss(nn.modules.loss._Loss): def __init__(self, args, ckp): super(Loss, self).__init__() print('Preparing loss function:') self.n_GPUs = args.n_GPUs self.loss = [] self.loss_module = nn.ModuleList() for loss in args.loss.split('+'): weight, loss_type = loss.split('*') if loss_type == 'MSE': loss_function = nn.MSELoss() elif loss_type == 'L1': loss_function = nn.L1Loss() elif loss_type == 'CE': loss_function = nn.CrossEntropyLoss() elif loss_type.find('VGG') >= 0: module = import_module('loss.vgg') loss_function = getattr(module, 'VGG')( loss_type[3:], rgb_range=args.rgb_range ) elif loss_type.find('GAN') >= 0: module = import_module('loss.adversarial') loss_function = getattr(module, 'Adversarial')( args, loss_type ) self.loss.append({ 'type': loss_type, 'weight': float(weight), 'function': loss_function} ) if loss_type.find('GAN') >= 0: self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) if len(self.loss) > 1: self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) for l in self.loss: if l['function'] is not None: print('{:.3f} * {}'.format(l['weight'], l['type'])) self.loss_module.append(l['function']) self.log = torch.Tensor() device = torch.device('cpu' if args.cpu else 'cuda') self.loss_module.to(device) if args.precision == 'half': self.loss_module.half() if not args.cpu and args.n_GPUs > 1: self.loss_module = nn.DataParallel( self.loss_module, range(args.n_GPUs) ) if args.load != '.': self.load(ckp.dir, cpu=args.cpu) def forward(self, sr, hr): losses = [] for i, l in enumerate(self.loss): if l['function'] is not None: loss = l['function'](sr, hr) effective_loss = l['weight'] * loss losses.append(effective_loss) self.log[-1, i] += effective_loss.item() elif l['type'] == 'DIS': self.log[-1, i] += self.loss[i - 1]['function'].loss loss_sum = sum(losses) if len(self.loss) > 1: self.log[-1, -1] += loss_sum.item() return loss_sum def step(self): for l in self.get_loss_module(): if hasattr(l, 'scheduler'): l.scheduler.step() def start_log(self): self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) def end_log(self, n_batches): self.log[-1].div_(n_batches) def display_loss(self, batch): n_samples = batch + 1 log = [] for l, c in zip(self.loss, self.log[-1]): log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) return ''.join(log) def plot_loss(self, apath, epoch): axis = np.linspace(1, epoch, epoch) for i, l in enumerate(self.loss): label = '{} Loss'.format(l['type']) fig = plt.figure() plt.title(label) plt.plot(axis, self.log[:, i].numpy(), label=label) plt.legend() plt.xlabel('Epochs') plt.ylabel('Loss') plt.grid(True) plt.savefig('{}/loss_{}.pdf'.format(apath, l['type'])) plt.close(fig) def get_loss_module(self): if self.n_GPUs == 1: return self.loss_module else: return self.loss_module.module def save(self, apath): torch.save(self.state_dict(), os.path.join(apath, 'loss.pt')) torch.save(self.log, os.path.join(apath, 'loss_log.pt')) def load(self, apath, cpu=False): if cpu: kwargs = {'map_location': lambda storage, loc: storage} else: kwargs = {} self.load_state_dict(torch.load( os.path.join(apath, 'loss.pt'), **kwargs )) self.log = torch.load(os.path.join(apath, 'loss_log.pt')) for l in self.loss_module: if hasattr(l, 'scheduler'): for _ in range(len(self.log)): l.scheduler.step() ================================================ FILE: loss/adversarial.py ================================================ import utility from model import common from loss import discriminator import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.autograd import Variable class Adversarial(nn.Module): def __init__(self, args, gan_type): super(Adversarial, self).__init__() self.gan_type = gan_type self.gan_k = args.gan_k self.discriminator = discriminator.Discriminator(args, gan_type) if gan_type != 'WGAN_GP': self.optimizer = utility.make_optimizer(args, self.discriminator) else: self.optimizer = optim.Adam( self.discriminator.parameters(), betas=(0, 0.9), eps=1e-8, lr=1e-5 ) self.scheduler = utility.make_scheduler(args, self.optimizer) def forward(self, fake, real): fake_detach = fake.detach() self.loss = 0 for _ in range(self.gan_k): self.optimizer.zero_grad() d_fake = self.discriminator(fake_detach) d_real = self.discriminator(real) if self.gan_type == 'GAN': label_fake = torch.zeros_like(d_fake) label_real = torch.ones_like(d_real) loss_d \ = F.binary_cross_entropy_with_logits(d_fake, label_fake) \ + F.binary_cross_entropy_with_logits(d_real, label_real) elif self.gan_type.find('WGAN') >= 0: loss_d = (d_fake - d_real).mean() if self.gan_type.find('GP') >= 0: epsilon = torch.rand_like(fake).view(-1, 1, 1, 1) hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon) hat.requires_grad = True d_hat = self.discriminator(hat) gradients = torch.autograd.grad( outputs=d_hat.sum(), inputs=hat, retain_graph=True, create_graph=True, only_inputs=True )[0] gradients = gradients.view(gradients.size(0), -1) gradient_norm = gradients.norm(2, dim=1) gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() loss_d += gradient_penalty # Discriminator update self.loss += loss_d.item() loss_d.backward() self.optimizer.step() if self.gan_type == 'WGAN': for p in self.discriminator.parameters(): p.data.clamp_(-1, 1) self.loss /= self.gan_k d_fake_for_g = self.discriminator(fake) if self.gan_type == 'GAN': loss_g = F.binary_cross_entropy_with_logits( d_fake_for_g, label_real ) elif self.gan_type.find('WGAN') >= 0: loss_g = -d_fake_for_g.mean() # Generator loss return loss_g def state_dict(self, *args, **kwargs): state_discriminator = self.discriminator.state_dict(*args, **kwargs) state_optimizer = self.optimizer.state_dict() return dict(**state_discriminator, **state_optimizer) # Some references # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py # OR # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py ================================================ FILE: loss/discriminator.py ================================================ from model import common import torch.nn as nn class Discriminator(nn.Module): def __init__(self, args, gan_type='GAN'): super(Discriminator, self).__init__() in_channels = 3 out_channels = 64 depth = 7 #bn = not gan_type == 'WGAN_GP' bn = True act = nn.LeakyReLU(negative_slope=0.2, inplace=True) m_features = [ common.BasicBlock(args.n_colors, out_channels, 3, bn=bn, act=act) ] for i in range(depth): in_channels = out_channels if i % 2 == 1: stride = 1 out_channels *= 2 else: stride = 2 m_features.append(common.BasicBlock( in_channels, out_channels, 3, stride=stride, bn=bn, act=act )) self.features = nn.Sequential(*m_features) patch_size = args.patch_size // (2**((depth + 1) // 2)) m_classifier = [ nn.Linear(out_channels * patch_size**2, 1024), act, nn.Linear(1024, 1) ] self.classifier = nn.Sequential(*m_classifier) def forward(self, x): features = self.features(x) output = self.classifier(features.view(features.size(0), -1)) return output ================================================ FILE: loss/vgg.py ================================================ from model import common import torch import torch.nn as nn import torch.nn.functional as F import torchvision.models as models from torch.autograd import Variable class VGG(nn.Module): def __init__(self, conv_index, rgb_range=1): super(VGG, self).__init__() vgg_features = models.vgg19(pretrained=True).features modules = [m for m in vgg_features] if conv_index == '22': self.vgg = nn.Sequential(*modules[:8]) elif conv_index == '54': self.vgg = nn.Sequential(*modules[:35]) vgg_mean = (0.485, 0.456, 0.406) vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) self.vgg.requires_grad = False def forward(self, sr, hr): def _forward(x): x = self.sub_mean(x) x = self.vgg(x) return x vgg_sr = _forward(sr) with torch.no_grad(): vgg_hr = _forward(hr.detach()) loss = F.mse_loss(vgg_sr, vgg_hr) return loss ================================================ FILE: main.py ================================================ from option import args import torch import utility import data import model import loss from trainer import Trainer if __name__ == '__main__': torch.manual_seed(args.seed) checkpoint = utility.checkpoint(args) if checkpoint.ok: loader = data.Data(args) model = model.Model(args, checkpoint) loss = loss.Loss(args, checkpoint) if not args.test_only else None t = Trainer(args, loader, model, loss, checkpoint) while not t.terminate(): t.train() checkpoint.done() ================================================ FILE: main.sh ================================================ # noise-free degradations with isotropic Gaussian blurs python main.py --dir_data='D:/LongguangWang/Data' \ --model='blindsr' \ --scale='2' \ --blur_type='iso_gaussian' \ --noise=0.0 \ --sig_min=0.2 \ --sig_max=4.0 # general degradations with anisotropic Gaussian blurs and noises python main.py --dir_data='D:/LongguangWang/Data' \ --model='blindsr' \ --scale='4' \ --blur_type='aniso_gaussian' \ --noise=25.0 \ --lambda_min=0.2 \ --lambda_max=4.0 cmd /k ================================================ FILE: moco/__init__.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved ================================================ FILE: moco/builder.py ================================================ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import torch import torch.nn as nn class MoCo(nn.Module): """ Build a MoCo model with: a query encoder, a key encoder, and a queue https://arxiv.org/abs/1911.05722 """ def __init__(self, base_encoder, dim=256, K=32*256, m=0.999, T=0.07, mlp=False): """ dim: feature dimension (default: 128) K: queue size; number of negative keys (default: 65536) m: moco momentum of updating key encoder (default: 0.999) T: softmax temperature (default: 0.07) """ super(MoCo, self).__init__() self.K = K self.m = m self.T = T # create the encoders # num_classes is the output fc dimension self.encoder_q = base_encoder() self.encoder_k = base_encoder() for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): param_k.data.copy_(param_q.data) # initialize param_k.requires_grad = False # not update by gradient # create the queue self.register_buffer("queue", torch.randn(dim, K)) self.queue = nn.functional.normalize(self.queue, dim=0) self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) @torch.no_grad() def _momentum_update_key_encoder(self): """ Momentum update of the key encoder """ for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) @torch.no_grad() def _dequeue_and_enqueue(self, keys): # gather keys before updating queue # keys = concat_all_gather(keys) batch_size = keys.shape[0] ptr = int(self.queue_ptr) assert self.K % batch_size == 0 # for simplicity # replace the keys at ptr (dequeue and enqueue) self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1) ptr = (ptr + batch_size) % self.K # move pointer self.queue_ptr[0] = ptr @torch.no_grad() def _batch_shuffle_ddp(self, x): """ Batch shuffle, for making use of BatchNorm. *** Only support DistributedDataParallel (DDP) model. *** """ # gather from all gpus batch_size_this = x.shape[0] x_gather = concat_all_gather(x) batch_size_all = x_gather.shape[0] num_gpus = batch_size_all // batch_size_this # random shuffle index idx_shuffle = torch.randperm(batch_size_all).cuda() # broadcast to all gpus torch.distributed.broadcast(idx_shuffle, src=0) # index for restoring idx_unshuffle = torch.argsort(idx_shuffle) # shuffled index for this gpu gpu_idx = torch.distributed.get_rank() idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] return x_gather[idx_this], idx_unshuffle @torch.no_grad() def _batch_unshuffle_ddp(self, x, idx_unshuffle): """ Undo batch shuffle. *** Only support DistributedDataParallel (DDP) model. *** """ # gather from all gpus batch_size_this = x.shape[0] x_gather = concat_all_gather(x) batch_size_all = x_gather.shape[0] num_gpus = batch_size_all // batch_size_this # restored index for this gpu gpu_idx = torch.distributed.get_rank() idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] return x_gather[idx_this] def forward(self, im_q, im_k): """ Input: im_q: a batch of query images im_k: a batch of key images Output: logits, targets """ if self.training: # compute query features embedding, q = self.encoder_q(im_q) # queries: NxC q = nn.functional.normalize(q, dim=1) # compute key features with torch.no_grad(): # no gradient to keys self._momentum_update_key_encoder() # update the key encoder _, k = self.encoder_k(im_k) # keys: NxC k = nn.functional.normalize(k, dim=1) # compute logits # Einstein sum is more intuitive # positive logits: Nx1 l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) # negative logits: NxK l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) # logits: Nx(1+K) logits = torch.cat([l_pos, l_neg], dim=1) # apply temperature logits /= self.T # labels: positive key indicators labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() # dequeue and enqueue self._dequeue_and_enqueue(k) return embedding, logits, labels else: embedding, _ = self.encoder_q(im_q) return embedding # utils @torch.no_grad() def concat_all_gather(tensor): """ Performs all_gather operation on the provided tensors. *** Warning ***: torch.distributed.all_gather has no gradient. """ tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())] torch.distributed.all_gather(tensors_gather, tensor, async_op=False) output = torch.cat(tensors_gather, dim=0) return output ================================================ FILE: model/__init__.py ================================================ import os from importlib import import_module import torch import torch.nn as nn class Model(nn.Module): def __init__(self, args, ckp): super(Model, self).__init__() print('Making model...') self.args = args self.scale = args.scale self.idx_scale = 0 self.self_ensemble = args.self_ensemble self.chop = args.chop self.precision = args.precision self.cpu = args.cpu self.device = torch.device('cpu' if args.cpu else 'cuda') self.n_GPUs = args.n_GPUs self.save_models = args.save_models self.save = args.save module = import_module('model.'+args.model) self.model = module.make_model(args).to(self.device) if args.precision == 'half': self.model.half() if not args.cpu and args.n_GPUs > 1: self.model = nn.DataParallel(self.model, range(args.n_GPUs)) self.load( ckp.dir, pre_train=args.pre_train, resume=args.resume, cpu=args.cpu ) def forward(self, x): if self.self_ensemble and not self.training: if self.chop: forward_function = self.forward_chop else: forward_function = self.model.forward return self.forward_x8(x, forward_function) elif self.chop and not self.training: return self.forward_chop(x) else: return self.model(x) def get_model(self): if self.n_GPUs <= 1 or self.cpu: return self.model else: return self.model.module def state_dict(self, **kwargs): target = self.get_model() return target.state_dict(**kwargs) def save(self, apath, epoch, is_best=False): target = self.get_model() torch.save( target.state_dict(), os.path.join(apath, 'model', 'model_latest.pt') ) if is_best: torch.save( target.state_dict(), os.path.join(apath, 'model', 'model_best.pt') ) if self.save_models: torch.save( target.state_dict(), os.path.join(apath, 'model', 'model_{}.pt'.format(epoch)) ) def load(self, apath, pre_train='.', resume=-1, cpu=False): if cpu: kwargs = {'map_location': lambda storage, loc: storage} else: kwargs = {} if resume == -1: self.get_model().load_state_dict( torch.load(os.path.join(apath, 'model', 'model_latest.pt'), **kwargs), strict=True ) elif resume == 0: if pre_train != '.': self.get_model().load_state_dict( torch.load(pre_train, **kwargs), strict=True ) elif resume > 0: self.get_model().load_state_dict( torch.load(os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), **kwargs), strict=False ) def forward_chop(self, x, shave=10, min_size=160000): scale = self.scale[self.idx_scale] n_GPUs = min(self.n_GPUs, 4) b, c, h, w = x.size() h_half, w_half = h // 2, w // 2 h_size, w_size = h_half + shave, w_half + shave lr_list = [ x[:, :, 0:h_size, 0:w_size], x[:, :, 0:h_size, (w - w_size):w], x[:, :, (h - h_size):h, 0:w_size], x[:, :, (h - h_size):h, (w - w_size):w]] if w_size * h_size < min_size: sr_list = [] for i in range(0, 4, n_GPUs): lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) sr_batch = self.model(lr_batch) sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) else: sr_list = [ self.forward_chop(patch, shave=shave, min_size=min_size) \ for patch in lr_list ] h, w = scale * h, scale * w h_half, w_half = scale * h_half, scale * w_half h_size, w_size = scale * h_size, scale * w_size shave *= scale output = x.new(b, c, h, w) output[:, :, 0:h_half, 0:w_half] \ = sr_list[0][:, :, 0:h_half, 0:w_half] output[:, :, 0:h_half, w_half:w] \ = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] output[:, :, h_half:h, 0:w_half] \ = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] output[:, :, h_half:h, w_half:w] \ = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] return output def forward_x8(self, x, forward_function): def _transform(v, op): if self.precision != 'single': v = v.float() v2np = v.data.cpu().numpy() if op == 'v': tfnp = v2np[:, :, :, ::-1].copy() elif op == 'h': tfnp = v2np[:, :, ::-1, :].copy() elif op == 't': tfnp = v2np.transpose((0, 1, 3, 2)).copy() ret = torch.Tensor(tfnp).to(self.device) if self.precision == 'half': ret = ret.half() return ret lr_list = [x] for tf in 'v', 'h', 't': lr_list.extend([_transform(t, tf) for t in lr_list]) sr_list = [forward_function(aug) for aug in lr_list] for i in range(len(sr_list)): if i > 3: sr_list[i] = _transform(sr_list[i], 't') if i % 4 > 1: sr_list[i] = _transform(sr_list[i], 'h') if (i % 4) % 2 == 1: sr_list[i] = _transform(sr_list[i], 'v') output_cat = torch.cat(sr_list, dim=0) output = output_cat.mean(dim=0, keepdim=True) return output ================================================ FILE: model/blindsr.py ================================================ import torch from torch import nn import model.common as common import torch.nn.functional as F from moco.builder import MoCo def make_model(args): return BlindSR(args) class DA_conv(nn.Module): def __init__(self, channels_in, channels_out, kernel_size, reduction): super(DA_conv, self).__init__() self.channels_out = channels_out self.channels_in = channels_in self.kernel_size = kernel_size self.kernel = nn.Sequential( nn.Linear(64, 64, bias=False), nn.LeakyReLU(0.1, True), nn.Linear(64, 64 * self.kernel_size * self.kernel_size, bias=False) ) self.conv = common.default_conv(channels_in, channels_out, 1) self.ca = CA_layer(channels_in, channels_out, reduction) self.relu = nn.LeakyReLU(0.1, True) def forward(self, x): ''' :param x[0]: feature map: B * C * H * W :param x[1]: degradation representation: B * C ''' b, c, h, w = x[0].size() # branch 1 kernel = self.kernel(x[1]).view(-1, 1, self.kernel_size, self.kernel_size) out = self.relu(F.conv2d(x[0].view(1, -1, h, w), kernel, groups=b*c, padding=(self.kernel_size-1)//2)) out = self.conv(out.view(b, -1, h, w)) # branch 2 out = out + self.ca(x) return out class CA_layer(nn.Module): def __init__(self, channels_in, channels_out, reduction): super(CA_layer, self).__init__() self.conv_du = nn.Sequential( nn.Conv2d(channels_in, channels_in//reduction, 1, 1, 0, bias=False), nn.LeakyReLU(0.1, True), nn.Conv2d(channels_in // reduction, channels_out, 1, 1, 0, bias=False), nn.Sigmoid() ) def forward(self, x): ''' :param x[0]: feature map: B * C * H * W :param x[1]: degradation representation: B * C ''' att = self.conv_du(x[1][:, :, None, None]) return x[0] * att class DAB(nn.Module): def __init__(self, conv, n_feat, kernel_size, reduction): super(DAB, self).__init__() self.da_conv1 = DA_conv(n_feat, n_feat, kernel_size, reduction) self.da_conv2 = DA_conv(n_feat, n_feat, kernel_size, reduction) self.conv1 = conv(n_feat, n_feat, kernel_size) self.conv2 = conv(n_feat, n_feat, kernel_size) self.relu = nn.LeakyReLU(0.1, True) def forward(self, x): ''' :param x[0]: feature map: B * C * H * W :param x[1]: degradation representation: B * C ''' out = self.relu(self.da_conv1(x)) out = self.relu(self.conv1(out)) out = self.relu(self.da_conv2([out, x[1]])) out = self.conv2(out) + x[0] return out class DAG(nn.Module): def __init__(self, conv, n_feat, kernel_size, reduction, n_blocks): super(DAG, self).__init__() self.n_blocks = n_blocks modules_body = [ DAB(conv, n_feat, kernel_size, reduction) \ for _ in range(n_blocks) ] modules_body.append(conv(n_feat, n_feat, kernel_size)) self.body = nn.Sequential(*modules_body) def forward(self, x): ''' :param x[0]: feature map: B * C * H * W :param x[1]: degradation representation: B * C ''' res = x[0] for i in range(self.n_blocks): res = self.body[i]([res, x[1]]) res = self.body[-1](res) res = res + x[0] return res class DASR(nn.Module): def __init__(self, args, conv=common.default_conv): super(DASR, self).__init__() self.n_groups = 5 n_blocks = 5 n_feats = 64 kernel_size = 3 reduction = 8 scale = int(args.scale[0]) # RGB mean for DIV2K rgb_mean = (0.4488, 0.4371, 0.4040) rgb_std = (1.0, 1.0, 1.0) self.sub_mean = common.MeanShift(255.0, rgb_mean, rgb_std) self.add_mean = common.MeanShift(255.0, rgb_mean, rgb_std, 1) # head module modules_head = [conv(3, n_feats, kernel_size)] self.head = nn.Sequential(*modules_head) # compress self.compress = nn.Sequential( nn.Linear(256, 64, bias=False), nn.LeakyReLU(0.1, True) ) # body modules_body = [ DAG(common.default_conv, n_feats, kernel_size, reduction, n_blocks) \ for _ in range(self.n_groups) ] modules_body.append(conv(n_feats, n_feats, kernel_size)) self.body = nn.Sequential(*modules_body) # tail modules_tail = [common.Upsampler(conv, scale, n_feats, act=False), conv(n_feats, 3, kernel_size)] self.tail = nn.Sequential(*modules_tail) def forward(self, x, k_v): k_v = self.compress(k_v) # sub mean x = self.sub_mean(x) # head x = self.head(x) # body res = x for i in range(self.n_groups): res = self.body[i]([res, k_v]) res = self.body[-1](res) res = res + x # tail x = self.tail(res) # add mean x = self.add_mean(x) return x class Encoder(nn.Module): def __init__(self): super(Encoder, self).__init__() self.E = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.1, True), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.1, True), nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.1, True), nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.1, True), nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.1, True), nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.1, True), nn.AdaptiveAvgPool2d(1), ) self.mlp = nn.Sequential( nn.Linear(256, 256), nn.LeakyReLU(0.1, True), nn.Linear(256, 256), ) def forward(self, x): fea = self.E(x).squeeze(-1).squeeze(-1) out = self.mlp(fea) return fea, out class BlindSR(nn.Module): def __init__(self, args): super(BlindSR, self).__init__() # Generator self.G = DASR(args) # Encoder self.E = MoCo(base_encoder=Encoder) def forward(self, x): if self.training: x_query = x[:, 0, ...] # b, c, h, w x_key = x[:, 1, ...] # b, c, h, w # degradation-aware represenetion learning fea, logits, labels = self.E(x_query, x_key) # degradation-aware SR sr = self.G(x_query, fea) return sr, logits, labels else: # degradation-aware represenetion learning fea = self.E(x, x) # degradation-aware SR sr = self.G(x, fea) return sr ================================================ FILE: model/common.py ================================================ import math import torch import torch.nn as nn import torch.nn.functional as F def default_conv(in_channels, out_channels, kernel_size, bias=True): return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias) class MeanShift(nn.Conv2d): def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): super(MeanShift, self).__init__(3, 3, kernel_size=1) std = torch.Tensor(rgb_std) self.weight.data = torch.eye(3).view(3, 3, 1, 1) self.weight.data.div_(std.view(3, 1, 1, 1)) self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) self.bias.data.div_(std) self.weight.requires_grad = False self.bias.requires_grad = False class Upsampler(nn.Sequential): def __init__(self, conv, scale, n_feat, act=False, bias=True): m = [] if (scale & (scale - 1)) == 0: # Is scale = 2^n? for _ in range(int(math.log(scale, 2))): m.append(conv(n_feat, 4 * n_feat, 3, bias)) m.append(nn.PixelShuffle(2)) if act: m.append(act()) elif scale == 3: m.append(conv(n_feat, 9 * n_feat, 3, bias)) m.append(nn.PixelShuffle(3)) if act: m.append(act()) else: raise NotImplementedError super(Upsampler, self).__init__(*m) ================================================ FILE: option.py ================================================ import argparse import template parser = argparse.ArgumentParser(description='EDSR and MDSR') parser.add_argument('--debug', action='store_true', help='Enables debug mode') parser.add_argument('--template', default='.', help='You can set various templates in option.py') # Hardware specifications parser.add_argument('--n_threads', type=int, default=4, help='number of threads for data loading') parser.add_argument('--cpu', type=bool, default=False, help='use cpu only') parser.add_argument('--n_GPUs', type=int, default=2, help='number of GPUs') parser.add_argument('--seed', type=int, default=1, help='random seed') # Data specifications parser.add_argument('--dir_data', type=str, default='D:/LongguangWang/Data', help='dataset directory') parser.add_argument('--dir_demo', type=str, default='../test', help='demo image directory') parser.add_argument('--data_train', type=str, default='DF2K', help='train dataset name') parser.add_argument('--data_test', type=str, default='Set14', help='test dataset name') parser.add_argument('--data_range', type=str, default='1-3450/801-810', help='train/test data range') parser.add_argument('--ext', type=str, default='sep', help='dataset file extension') parser.add_argument('--scale', type=str, default='4', help='super resolution scale') parser.add_argument('--patch_size', type=int, default=48, help='output patch size') parser.add_argument('--rgb_range', type=int, default=255, help='maximum value of RGB') parser.add_argument('--n_colors', type=int, default=3, help='number of color channels to use') parser.add_argument('--chop', action='store_true', help='enable memory-efficient forward') parser.add_argument('--no_augment', action='store_true', help='do not use data augmentation') # Degradation specifications parser.add_argument('--blur_kernel', type=int, default=21, help='size of blur kernels') parser.add_argument('--blur_type', type=str, default='iso_gaussian', help='blur types (iso_gaussian | aniso_gaussian)') parser.add_argument('--mode', type=str, default='bicubic', help='downsampler (bicubic | s-fold)') parser.add_argument('--noise', type=float, default=0.0, help='noise level') ## isotropic Gaussian blur parser.add_argument('--sig_min', type=float, default=0.2, help='minimum sigma of isotropic Gaussian blurs') parser.add_argument('--sig_max', type=float, default=4.0, help='maximum sigma of isotropic Gaussian blurs') parser.add_argument('--sig', type=float, default=4.0, help='specific sigma of isotropic Gaussian blurs') ## anisotropic Gaussian blur parser.add_argument('--lambda_min', type=float, default=0.2, help='minimum value for the eigenvalue of anisotropic Gaussian blurs') parser.add_argument('--lambda_max', type=float, default=4.0, help='maximum value for the eigenvalue of anisotropic Gaussian blurs') parser.add_argument('--lambda_1', type=float, default=0.2, help='one eigenvalue of anisotropic Gaussian blurs') parser.add_argument('--lambda_2', type=float, default=4.0, help='another eigenvalue of anisotropic Gaussian blurs') parser.add_argument('--theta', type=float, default=0.0, help='rotation angle of anisotropic Gaussian blurs [0, 180]') # Model specifications parser.add_argument('--model', default='blindsr', help='model name') parser.add_argument('--pre_train', type=str, default= '.', help='pre-trained model directory') parser.add_argument('--extend', type=str, default='.', help='pre-trained model directory') parser.add_argument('--shift_mean', default=True, help='subtract pixel mean from the input') parser.add_argument('--dilation', action='store_true', help='use dilated convolution') parser.add_argument('--precision', type=str, default='single', choices=('single', 'half'), help='FP precision for test (single | half)') # Training specifications parser.add_argument('--reset', action='store_true', help='reset the training') parser.add_argument('--test_every', type=int, default=1000, help='do test per every N batches') parser.add_argument('--epochs_encoder', type=int, default=100, help='number of epochs to train the degradation encoder') parser.add_argument('--epochs_sr', type=int, default=500, help='number of epochs to train the whole network') parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training') parser.add_argument('--split_batch', type=int, default=1, help='split the batch into smaller chunks') parser.add_argument('--self_ensemble', action='store_true', help='use self-ensemble method for test') parser.add_argument('--test_only', action='store_true', help='set this option to test the model') # Optimization specifications parser.add_argument('--lr_encoder', type=float, default=1e-3, help='learning rate to train the degradation encoder') parser.add_argument('--lr_sr', type=float, default=1e-4, help='learning rate to train the whole network') parser.add_argument('--lr_decay_encoder', type=int, default=60, help='learning rate decay per N epochs') parser.add_argument('--lr_decay_sr', type=int, default=125, help='learning rate decay per N epochs') parser.add_argument('--decay_type', type=str, default='step', help='learning rate decay type') parser.add_argument('--gamma_encoder', type=float, default=0.1, help='learning rate decay factor for step decay') parser.add_argument('--gamma_sr', type=float, default=0.5, help='learning rate decay factor for step decay') parser.add_argument('--optimizer', default='ADAM', choices=('SGD', 'ADAM', 'RMSprop'), help='optimizer to use (SGD | ADAM | RMSprop)') parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum') parser.add_argument('--beta1', type=float, default=0.9, help='ADAM beta1') parser.add_argument('--beta2', type=float, default=0.999, help='ADAM beta2') parser.add_argument('--epsilon', type=float, default=1e-8, help='ADAM epsilon for numerical stability') parser.add_argument('--weight_decay', type=float, default=0, help='weight decay') parser.add_argument('--start_epoch', type=int, default=0, help='resume from the snapshot, and the start_epoch') # Loss specifications parser.add_argument('--loss', type=str, default='1*L1', help='loss function configuration') parser.add_argument('--skip_threshold', type=float, default='1e6', help='skipping batch that has large error') # Log specifications parser.add_argument('--save', type=str, default='blindsr', help='file name to save') parser.add_argument('--load', type=str, default='.', help='file name to load') parser.add_argument('--resume', type=int, default=0, help='resume from specific checkpoint') parser.add_argument('--save_models', action='store_true', help='save all intermediate models') parser.add_argument('--print_every', type=int, default=200, help='how many batches to wait before logging training status') parser.add_argument('--save_results', default=False, help='save output results') args = parser.parse_args() template.set_template(args) args.scale = list(map(lambda x: float(x), args.scale.split('+'))) ================================================ FILE: quick_test.py ================================================ from model.blindsr import BlindSR import torch import numpy as np import imageio import argparse import os import utility import cv2 def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--img_dir', type=str, default='D:/LongguangWang/Data/test.png', help='image directory') parser.add_argument('--scale', type=str, default='2', help='super resolution scale') parser.add_argument('--resume', type=int, default=600, help='resume from specific checkpoint') parser.add_argument('--blur_type', type=str, default='iso_gaussian', help='blur types (iso_gaussian | aniso_gaussian)') return parser.parse_args() def main(): args = parse_args() if args.blur_type == 'iso_gaussian': dir = './experiment/blindsr_x' + str(int(args.scale[0])) + '_bicubic_iso' elif args.blur_type == 'aniso_gaussian': dir = './experiment/blindsr_x' + str(int(args.scale[0])) + '_bicubic_aniso' # path to save sr images save_dir = dir + '/results' if not os.path.exists(save_dir): os.mkdir(save_dir) DASR = BlindSR(args).cuda() DASR.load_state_dict(torch.load(dir + '/model/model_' + str(args.resume) + '.pt'), strict=False) DASR.eval() lr = imageio.imread(args.img_dir) lr = np.ascontiguousarray(lr.transpose((2, 0, 1))) lr = torch.from_numpy(lr).float().cuda().unsqueeze(0).unsqueeze(0) # inference sr = DASR(lr[:, 0, ...]) sr = utility.quantize(sr, 255.0) # save sr results img_name = args.img_dir.split('.png')[0].split('/')[-1] sr = np.array(sr.squeeze(0).permute(1, 2, 0).data.cpu()) sr = sr[:, :, [2, 1, 0]] cv2.imwrite(save_dir + '/' + img_name + '_sr.png', sr) if __name__ == '__main__': with torch.no_grad(): main() ================================================ FILE: quick_test.sh ================================================ # super-resolve an LR image (x2) using the model trained on noise-free degradations with isotropic Gaussian blurs python quick_test.py --img_dir='D:/LongguangWang/Data/test.png' \ --scale='2' \ --resume=600 \ --blur_type='iso_gaussian' # super-resolve an LR image (x4) using the model trained on general degradations with anisotropic Gaussian blurs and noises python quick_test.py --img_dir='D:/LongguangWang/Data/test.png' \ --scale='4' \ --resume=600 \ --blur_type='aniso_gaussian' cmd /k ================================================ FILE: template.py ================================================ def set_template(args): # Set the templates here if args.template.find('jpeg') >= 0: args.data_train = 'DIV2K_jpeg' args.data_test = 'DIV2K_jpeg' args.epochs = 200 args.lr_decay = 100 if args.template.find('EDSR_paper') >= 0: args.model = 'EDSR' args.n_resblocks = 32 args.n_feats = 256 args.res_scale = 0.1 if args.template.find('MDSR') >= 0: args.model = 'MDSR' args.patch_size = 48 args.epochs = 650 if args.template.find('DDBPN') >= 0: args.model = 'DDBPN' args.patch_size = 128 args.scale = '4' args.data_test = 'Set5' args.batch_size = 20 args.epochs = 1000 args.lr_decay = 500 args.gamma = 0.1 args.weight_decay = 1e-4 args.loss = '1*MSE' if args.template.find('GAN') >= 0: args.epochs = 200 args.lr = 5e-5 args.lr_decay = 150 if args.template.find('RCAN') >= 0: args.model = 'RCAN' args.n_resgroups = 10 args.n_resblocks = 20 args.n_feats = 64 args.chop = True ================================================ FILE: test.py ================================================ from option import args import torch import utility import data import model import loss from trainer import Trainer if __name__ == '__main__': torch.manual_seed(args.seed) checkpoint = utility.checkpoint(args) if checkpoint.ok: loader = data.Data(args) model = model.Model(args, checkpoint) loss = loss.Loss(args, checkpoint) if not args.test_only else None t = Trainer(args, loader, model, loss, checkpoint) while not t.terminate(): t.test() checkpoint.done() ================================================ FILE: test.sh ================================================ # noise-free degradations with isotropic Gaussian blurs python test.py --test_only \ --dir_data='D:/LongguangWang/Data' \ --data_test='Set14' \ --model='blindsr' \ --scale='2' \ --resume=600 \ --blur_type='iso_gaussian' \ --noise=0.0 \ --sig=1.2 # general degradations with anisotropic Gaussian blurs and noises python test.py --test_only \ --dir_data='D:/LongguangWang/Data' \ --data_test='Set14' \ --model='blindsr' \ --scale='4' \ --resume=600 \ --blur_type='aniso_gaussian' \ --noise=10.0 \ --theta=0.0 \ --lambda_1=0.2 \ --lambda_2=4.0 cmd /k ================================================ FILE: trainer.py ================================================ import os import utility import torch from decimal import Decimal import torch.nn.functional as F from utils import util class Trainer(): def __init__(self, args, loader, my_model, my_loss, ckp): self.args = args self.scale = args.scale self.ckp = ckp self.loader_train = loader.loader_train self.loader_test = loader.loader_test self.model = my_model self.model_E = torch.nn.DataParallel(self.model.get_model().E, range(self.args.n_GPUs)) self.loss = my_loss self.contrast_loss = torch.nn.CrossEntropyLoss().cuda() self.optimizer = utility.make_optimizer(args, self.model) self.scheduler = utility.make_scheduler(args, self.optimizer) if self.args.load != '.': self.optimizer.load_state_dict( torch.load(os.path.join(ckp.dir, 'optimizer.pt')) ) for _ in range(len(ckp.log)): self.scheduler.step() def train(self): self.scheduler.step() self.loss.step() epoch = self.scheduler.last_epoch + 1 # lr stepwise if epoch <= self.args.epochs_encoder: lr = self.args.lr_encoder * (self.args.gamma_encoder ** (epoch // self.args.lr_decay_encoder)) for param_group in self.optimizer.param_groups: param_group['lr'] = lr else: lr = self.args.lr_sr * (self.args.gamma_sr ** ((epoch - self.args.epochs_encoder) // self.args.lr_decay_sr)) for param_group in self.optimizer.param_groups: param_group['lr'] = lr self.ckp.write_log('[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))) self.loss.start_log() self.model.train() degrade = util.SRMDPreprocessing( self.scale[0], kernel_size=self.args.blur_kernel, blur_type=self.args.blur_type, sig_min=self.args.sig_min, sig_max=self.args.sig_max, lambda_min=self.args.lambda_min, lambda_max=self.args.lambda_max, noise=self.args.noise ) timer = utility.timer() losses_contrast, losses_sr = utility.AverageMeter(), utility.AverageMeter() for batch, (hr, _, idx_scale) in enumerate(self.loader_train): hr = hr.cuda() # b, n, c, h, w lr, b_kernels = degrade(hr) # bn, c, h, w self.optimizer.zero_grad() timer.tic() # forward ## train degradation encoder if epoch <= self.args.epochs_encoder: _, output, target = self.model_E(im_q=lr[:,0,...], im_k=lr[:,1,...]) loss_constrast = self.contrast_loss(output, target) loss = loss_constrast losses_contrast.update(loss_constrast.item()) ## train the whole network else: sr, output, target = self.model(lr) loss_SR = self.loss(sr, hr[:,0,...]) loss_constrast = self.contrast_loss(output, target) loss = loss_constrast + loss_SR losses_sr.update(loss_SR.item()) losses_contrast.update(loss_constrast.item()) # backward loss.backward() self.optimizer.step() timer.hold() if epoch <= self.args.epochs_encoder: if (batch + 1) % self.args.print_every == 0: self.ckp.write_log( 'Epoch: [{:03d}][{:04d}/{:04d}]\t' 'Loss [contrastive loss: {:.3f}]\t' 'Time [{:.1f}s]'.format( epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset), losses_contrast.avg, timer.release() )) else: if (batch + 1) % self.args.print_every == 0: self.ckp.write_log( 'Epoch: [{:04d}][{:04d}/{:04d}]\t' 'Loss [SR loss:{:.3f} | contrastive loss: {:.3f}]\t' 'Time [{:.1f}s]'.format( epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset), losses_sr.avg, losses_contrast.avg, timer.release(), )) self.loss.end_log(len(self.loader_train)) # save model target = self.model.get_model() model_dict = target.state_dict() keys = list(model_dict.keys()) for key in keys: if 'E.encoder_k' in key or 'queue' in key: del model_dict[key] torch.save( model_dict, os.path.join(self.ckp.dir, 'model', 'model_{}.pt'.format(epoch)) ) def test(self): self.ckp.write_log('\nEvaluation:') self.ckp.add_log(torch.zeros(1, len(self.scale))) self.model.eval() timer_test = utility.timer() with torch.no_grad(): for idx_scale, scale in enumerate(self.scale): self.loader_test.dataset.set_scale(idx_scale) eval_psnr = 0 eval_ssim = 0 degrade = util.SRMDPreprocessing( self.scale[0], kernel_size=self.args.blur_kernel, blur_type=self.args.blur_type, sig=self.args.sig, lambda_1=self.args.lambda_1, lambda_2=self.args.lambda_2, theta=self.args.theta, noise=self.args.noise ) for idx_img, (hr, filename, _) in enumerate(self.loader_test): hr = hr.cuda() # b, 1, c, h, w hr = self.crop_border(hr, scale) lr, _ = degrade(hr, random=False) # b, 1, c, h, w hr = hr[:, 0, ...] # b, c, h, w # inference timer_test.tic() sr = self.model(lr[:, 0, ...]) timer_test.hold() sr = utility.quantize(sr, self.args.rgb_range) hr = utility.quantize(hr, self.args.rgb_range) # metrics eval_psnr += utility.calc_psnr( sr, hr, scale, self.args.rgb_range, benchmark=self.loader_test.dataset.benchmark ) eval_ssim += utility.calc_ssim( sr, hr, scale, benchmark=self.loader_test.dataset.benchmark ) # save results if self.args.save_results: save_list = [sr] filename = filename[0] self.ckp.save_results(filename, save_list, scale) self.ckp.log[-1, idx_scale] = eval_psnr / len(self.loader_test) self.ckp.write_log( '[Epoch {}---{} x{}]\tPSNR: {:.3f} SSIM: {:.4f}'.format( self.args.resume, self.args.data_test, scale, eval_psnr / len(self.loader_test), eval_ssim / len(self.loader_test), )) def crop_border(self, img_hr, scale): b, n, c, h, w = img_hr.size() img_hr = img_hr[:, :, :, :int(h//scale*scale), :int(w//scale*scale)] return img_hr def terminate(self): if self.args.test_only: self.test() return True else: epoch = self.scheduler.last_epoch + 1 return epoch >= self.args.epochs_encoder + self.args.epochs_sr ================================================ FILE: utility.py ================================================ import os import math import time import datetime import matplotlib.pyplot as plt import numpy as np import scipy.misc as misc import cv2 import torch import torch.optim as optim import torch.optim.lr_scheduler as lrs class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count class timer(): def __init__(self): self.acc = 0 self.tic() def tic(self): self.t0 = time.time() def toc(self): return time.time() - self.t0 def hold(self): self.acc += self.toc() def release(self): ret = self.acc self.acc = 0 return ret def reset(self): self.acc = 0 class checkpoint(): def __init__(self, args): self.args = args self.ok = True self.log = torch.Tensor() now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') if args.blur_type == 'iso_gaussian': self.dir = './experiment/' + args.save + '_x' + str(int(args.scale[0])) + '_' + args.mode + '_iso' elif args.blur_type == 'aniso_gaussian': self.dir = './experiment/' + args.save + '_x' + str(int(args.scale[0])) + '_' + args.mode + '_aniso' def _make_dir(path): if not os.path.exists(path): os.makedirs(path) _make_dir(self.dir) _make_dir(self.dir + '/model') _make_dir(self.dir + '/results') open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w' self.log_file = open(self.dir + '/log.txt', open_type) with open(self.dir + '/config.txt', open_type) as f: f.write(now + '\n\n') for arg in vars(args): f.write('{}: {}\n'.format(arg, getattr(args, arg))) f.write('\n') def save(self, trainer, epoch, is_best=False): trainer.model.save(self.dir, epoch, is_best=is_best) trainer.loss.save(self.dir) trainer.loss.plot_loss(self.dir, epoch) self.plot_psnr(epoch) torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt')) torch.save( trainer.optimizer.state_dict(), os.path.join(self.dir, 'optimizer.pt') ) def add_log(self, log): self.log = torch.cat([self.log, log]) def write_log(self, log, refresh=False): print(log) self.log_file.write(log + '\n') if refresh: self.log_file.close() self.log_file = open(self.dir + '/log.txt', 'a') def done(self): self.log_file.close() def plot_psnr(self, epoch): axis = np.linspace(1, epoch, epoch) label = 'SR on {}'.format(self.args.data_test) fig = plt.figure() plt.title(label) for idx_scale, scale in enumerate(self.args.scale): plt.plot( axis, self.log[:, idx_scale].numpy(), label='Scale {}'.format(scale) ) plt.legend() plt.xlabel('Epochs') plt.ylabel('PSNR') plt.grid(True) plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test)) plt.close(fig) def save_results(self, filename, save_list, scale): filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale) normalized = save_list[0][0].data.mul(255 / self.args.rgb_range) ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() misc.imsave('{}{}.png'.format(filename, 'SR'), ndarr) def quantize(img, rgb_range): pixel_range = 255 / rgb_range return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) def calc_psnr(sr, hr, scale, rgb_range, benchmark=False): diff = (sr - hr).data.div(rgb_range) if benchmark: shave = scale if diff.size(1) > 1: convert = diff.new(1, 3, 1, 1) convert[0, 0, 0, 0] = 65.738 convert[0, 1, 0, 0] = 129.057 convert[0, 2, 0, 0] = 25.064 diff.mul_(convert).div_(256) diff = diff.sum(dim=1, keepdim=True) else: shave = scale + 6 import math shave = math.ceil(shave) valid = diff[:, :, shave:-shave, shave:-shave] mse = valid.pow(2).mean() return -10 * math.log10(mse) def calc_ssim(img1, img2, scale=2, benchmark=False): '''calculate SSIM the same outputs as MATLAB's img1, img2: [0, 255] ''' if benchmark: border = math.ceil(scale) else: border = math.ceil(scale) + 6 img1 = img1.data.squeeze().float().clamp(0, 255).round().cpu().numpy() img1 = np.transpose(img1, (1, 2, 0)) img2 = img2.data.squeeze().cpu().numpy() img2 = np.transpose(img2, (1, 2, 0)) img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 255.0 + 16.0 img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 255.0 + 16.0 if not img1.shape == img2.shape: raise ValueError('Input images must have the same dimensions.') h, w = img1.shape[:2] img1_y = img1_y[border:h - border, border:w - border] img2_y = img2_y[border:h - border, border:w - border] if img1_y.ndim == 2: return ssim(img1_y, img2_y) elif img1.ndim == 3: if img1.shape[2] == 3: ssims = [] for i in range(3): ssims.append(ssim(img1, img2)) return np.array(ssims).mean() elif img1.shape[2] == 1: return ssim(np.squeeze(img1), np.squeeze(img2)) else: raise ValueError('Wrong input image dimensions.') def ssim(img1, img2): C1 = (0.01 * 255) ** 2 C2 = (0.03 * 255) ** 2 img1 = img1.astype(np.float64) img2 = img2.astype(np.float64) kernel = cv2.getGaussianKernel(11, 1.5) window = np.outer(kernel, kernel.transpose()) mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] mu1_sq = mu1 ** 2 mu2_sq = mu2 ** 2 mu1_mu2 = mu1 * mu2 sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) return ssim_map.mean() def make_optimizer(args, my_model): trainable = filter(lambda x: x.requires_grad, my_model.parameters()) if args.optimizer == 'SGD': optimizer_function = optim.SGD kwargs = {'momentum': args.momentum} elif args.optimizer == 'ADAM': optimizer_function = optim.Adam kwargs = { 'betas': (args.beta1, args.beta2), 'eps': args.epsilon } elif args.optimizer == 'RMSprop': optimizer_function = optim.RMSprop kwargs = {'eps': args.epsilon} kwargs['weight_decay'] = args.weight_decay return optimizer_function(trainable, **kwargs) def make_scheduler(args, my_optimizer): if args.decay_type == 'step': scheduler = lrs.StepLR( my_optimizer, step_size=args.lr_decay_sr, gamma=args.gamma_sr, ) elif args.decay_type.find('step') >= 0: milestones = args.decay_type.split('_') milestones.pop(0) milestones = list(map(lambda x: int(x), milestones)) scheduler = lrs.MultiStepLR( my_optimizer, milestones=milestones, gamma=args.gamma ) scheduler.step(args.start_epoch - 1) return scheduler ================================================ FILE: utils/__init__.py ================================================ ================================================ FILE: utils/util.py ================================================ import math import numpy as np import torch import torch.nn as nn import torch.nn.functional as F def cal_sigma(sig_x, sig_y, radians): sig_x = sig_x.view(-1, 1, 1) sig_y = sig_y.view(-1, 1, 1) radians = radians.view(-1, 1, 1) D = torch.cat([F.pad(sig_x ** 2, [0, 1, 0, 0]), F.pad(sig_y ** 2, [1, 0, 0, 0])], 1) U = torch.cat([torch.cat([radians.cos(), -radians.sin()], 2), torch.cat([radians.sin(), radians.cos()], 2)], 1) sigma = torch.bmm(U, torch.bmm(D, U.transpose(1, 2))) return sigma def anisotropic_gaussian_kernel(batch, kernel_size, covar): ax = torch.arange(kernel_size).float().cuda() - kernel_size // 2 xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1) yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1) xy = torch.stack([xx, yy], -1).view(batch, -1, 2) inverse_sigma = torch.inverse(covar) kernel = torch.exp(- 0.5 * (torch.bmm(xy, inverse_sigma) * xy).sum(2)).view(batch, kernel_size, kernel_size) return kernel / kernel.sum([1, 2], keepdim=True) def isotropic_gaussian_kernel(batch, kernel_size, sigma): ax = torch.arange(kernel_size).float().cuda() - kernel_size//2 xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1) yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1) kernel = torch.exp(-(xx ** 2 + yy ** 2) / (2. * sigma.view(-1, 1, 1) ** 2)) return kernel / kernel.sum([1,2], keepdim=True) def random_anisotropic_gaussian_kernel(batch=1, kernel_size=21, lambda_min=0.2, lambda_max=4.0): theta = torch.rand(batch).cuda() * math.pi lambda_1 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min lambda_2 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min covar = cal_sigma(lambda_1, lambda_2, theta) kernel = anisotropic_gaussian_kernel(batch, kernel_size, covar) return kernel def stable_anisotropic_gaussian_kernel(kernel_size=21, theta=0, lambda_1=0.2, lambda_2=4.0): theta = torch.ones(1).cuda() * theta / 180 * math.pi lambda_1 = torch.ones(1).cuda() * lambda_1 lambda_2 = torch.ones(1).cuda() * lambda_2 covar = cal_sigma(lambda_1, lambda_2, theta) kernel = anisotropic_gaussian_kernel(1, kernel_size, covar) return kernel def random_isotropic_gaussian_kernel(batch=1, kernel_size=21, sig_min=0.2, sig_max=4.0): x = torch.rand(batch).cuda() * (sig_max - sig_min) + sig_min k = isotropic_gaussian_kernel(batch, kernel_size, x) return k def stable_isotropic_gaussian_kernel(kernel_size=21, sig=4.0): x = torch.ones(1).cuda() * sig k = isotropic_gaussian_kernel(1, kernel_size, x) return k def random_gaussian_kernel(batch, kernel_size=21, blur_type='iso_gaussian', sig_min=0.2, sig_max=4.0, lambda_min=0.2, lambda_max=4.0): if blur_type == 'iso_gaussian': return random_isotropic_gaussian_kernel(batch=batch, kernel_size=kernel_size, sig_min=sig_min, sig_max=sig_max) elif blur_type == 'aniso_gaussian': return random_anisotropic_gaussian_kernel(batch=batch, kernel_size=kernel_size, lambda_min=lambda_min, lambda_max=lambda_max) def stable_gaussian_kernel(kernel_size=21, blur_type='iso_gaussian', sig=2.6, lambda_1=0.2, lambda_2=4.0, theta=0): if blur_type == 'iso_gaussian': return stable_isotropic_gaussian_kernel(kernel_size=kernel_size, sig=sig) elif blur_type == 'aniso_gaussian': return stable_anisotropic_gaussian_kernel(kernel_size=kernel_size, lambda_1=lambda_1, lambda_2=lambda_2, theta=theta) # implementation of matlab bicubic interpolation in pytorch class bicubic(nn.Module): def __init__(self): super(bicubic, self).__init__() def cubic(self, x): absx = torch.abs(x) absx2 = torch.abs(x) * torch.abs(x) absx3 = torch.abs(x) * torch.abs(x) * torch.abs(x) condition1 = (absx <= 1).to(torch.float32) condition2 = ((1 < absx) & (absx <= 2)).to(torch.float32) f = (1.5 * absx3 - 2.5 * absx2 + 1) * condition1 + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * condition2 return f def contribute(self, in_size, out_size, scale): kernel_width = 4 if scale < 1: kernel_width = 4 / scale x0 = torch.arange(start=1, end=out_size[0] + 1).to(torch.float32).cuda() x1 = torch.arange(start=1, end=out_size[1] + 1).to(torch.float32).cuda() u0 = x0 / scale + 0.5 * (1 - 1 / scale) u1 = x1 / scale + 0.5 * (1 - 1 / scale) left0 = torch.floor(u0 - kernel_width / 2) left1 = torch.floor(u1 - kernel_width / 2) P = np.ceil(kernel_width) + 2 indice0 = left0.unsqueeze(1) + torch.arange(start=0, end=P).to(torch.float32).unsqueeze(0).cuda() indice1 = left1.unsqueeze(1) + torch.arange(start=0, end=P).to(torch.float32).unsqueeze(0).cuda() mid0 = u0.unsqueeze(1) - indice0.unsqueeze(0) mid1 = u1.unsqueeze(1) - indice1.unsqueeze(0) if scale < 1: weight0 = scale * self.cubic(mid0 * scale) weight1 = scale * self.cubic(mid1 * scale) else: weight0 = self.cubic(mid0) weight1 = self.cubic(mid1) weight0 = weight0 / (torch.sum(weight0, 2).unsqueeze(2)) weight1 = weight1 / (torch.sum(weight1, 2).unsqueeze(2)) indice0 = torch.min(torch.max(torch.FloatTensor([1]).cuda(), indice0), torch.FloatTensor([in_size[0]]).cuda()).unsqueeze(0) indice1 = torch.min(torch.max(torch.FloatTensor([1]).cuda(), indice1), torch.FloatTensor([in_size[1]]).cuda()).unsqueeze(0) kill0 = torch.eq(weight0, 0)[0][0] kill1 = torch.eq(weight1, 0)[0][0] weight0 = weight0[:, :, kill0 == 0] weight1 = weight1[:, :, kill1 == 0] indice0 = indice0[:, :, kill0 == 0] indice1 = indice1[:, :, kill1 == 0] return weight0, weight1, indice0, indice1 def forward(self, input, scale=1/4): b, c, h, w = input.shape weight0, weight1, indice0, indice1 = self.contribute([h, w], [int(h * scale), int(w * scale)], scale) weight0 = weight0[0] weight1 = weight1[0] indice0 = indice0[0].long() indice1 = indice1[0].long() out = input[:, :, (indice0 - 1), :] * (weight0.unsqueeze(0).unsqueeze(1).unsqueeze(4)) out = (torch.sum(out, dim=3)) A = out.permute(0, 1, 3, 2) out = A[:, :, (indice1 - 1), :] * (weight1.unsqueeze(0).unsqueeze(1).unsqueeze(4)) out = out.sum(3).permute(0, 1, 3, 2) return out class Gaussin_Kernel(object): def __init__(self, kernel_size=21, blur_type='iso_gaussian', sig=2.6, sig_min=0.2, sig_max=4.0, lambda_1=0.2, lambda_2=4.0, theta=0, lambda_min=0.2, lambda_max=4.0): self.kernel_size = kernel_size self.blur_type = blur_type self.sig = sig self.sig_min = sig_min self.sig_max = sig_max self.lambda_1 = lambda_1 self.lambda_2 = lambda_2 self.theta = theta self.lambda_min = lambda_min self.lambda_max = lambda_max def __call__(self, batch, random): # random kernel if random == True: return random_gaussian_kernel(batch, kernel_size=self.kernel_size, blur_type=self.blur_type, sig_min=self.sig_min, sig_max=self.sig_max, lambda_min=self.lambda_min, lambda_max=self.lambda_max) # stable kernel else: return stable_gaussian_kernel(kernel_size=self.kernel_size, blur_type=self.blur_type, sig=self.sig, lambda_1=self.lambda_1, lambda_2=self.lambda_2, theta=self.theta) class BatchBlur(nn.Module): def __init__(self, kernel_size=21): super(BatchBlur, self).__init__() self.kernel_size = kernel_size if kernel_size % 2 == 1: self.pad = nn.ReflectionPad2d(kernel_size//2) else: self.pad = nn.ReflectionPad2d((kernel_size//2, kernel_size//2-1, kernel_size//2, kernel_size//2-1)) def forward(self, input, kernel): B, C, H, W = input.size() input_pad = self.pad(input) H_p, W_p = input_pad.size()[-2:] if len(kernel.size()) == 2: input_CBHW = input_pad.view((C * B, 1, H_p, W_p)) kernel = kernel.contiguous().view((1, 1, self.kernel_size, self.kernel_size)) return F.conv2d(input_CBHW, kernel, padding=0).view((B, C, H, W)) else: input_CBHW = input_pad.view((1, C * B, H_p, W_p)) kernel = kernel.contiguous().view((B, 1, self.kernel_size, self.kernel_size)) kernel = kernel.repeat(1, C, 1, 1).view((B * C, 1, self.kernel_size, self.kernel_size)) return F.conv2d(input_CBHW, kernel, groups=B*C).view((B, C, H, W)) class SRMDPreprocessing(object): def __init__(self, scale, mode='bicubic', kernel_size=21, blur_type='iso_gaussian', sig=2.6, sig_min=0.2, sig_max=4.0, lambda_1=0.2, lambda_2=4.0, theta=0, lambda_min=0.2, lambda_max=4.0, noise=0.0 ): ''' # sig, sig_min and sig_max are used for isotropic Gaussian blurs During training phase (random=True): the width of the blur kernel is randomly selected from [sig_min, sig_max] During test phase (random=False): the width of the blur kernel is set to sig # lambda_1, lambda_2, theta, lambda_min and lambda_max are used for anisotropic Gaussian blurs During training phase (random=True): the eigenvalues of the covariance is randomly selected from [lambda_min, lambda_max] the angle value is randomly selected from [0, pi] During test phase (random=False): the eigenvalues of the covariance are set to lambda_1 and lambda_2 the angle value is set to theta ''' self.kernel_size = kernel_size self.scale = scale self.mode = mode self.noise = noise self.gen_kernel = Gaussin_Kernel( kernel_size=kernel_size, blur_type=blur_type, sig=sig, sig_min=sig_min, sig_max=sig_max, lambda_1=lambda_1, lambda_2=lambda_2, theta=theta, lambda_min=lambda_min, lambda_max=lambda_max ) self.blur = BatchBlur(kernel_size=kernel_size) self.bicubic = bicubic() def __call__(self, hr_tensor, random=True): with torch.no_grad(): # only downsampling if self.gen_kernel.blur_type == 'iso_gaussian' and self.gen_kernel.sig == 0: B, N, C, H, W = hr_tensor.size() hr_blured = hr_tensor.view(-1, C, H, W) b_kernels = None # gaussian blur + downsampling else: B, N, C, H, W = hr_tensor.size() b_kernels = self.gen_kernel(B, random) # B degradations # blur hr_blured = self.blur(hr_tensor.view(B, -1, H, W), b_kernels) hr_blured = hr_blured.view(-1, C, H, W) # BN, C, H, W # downsampling if self.mode == 'bicubic': lr_blured = self.bicubic(hr_blured, scale=1/self.scale) elif self.mode == 's-fold': lr_blured = hr_blured.view(-1, C, H//self.scale, self.scale, W//self.scale, self.scale)[:, :, :, 0, :, 0] # add noise if self.noise > 0: _, C, H_lr, W_lr = lr_blured.size() noise_level = torch.rand(B, 1, 1, 1, 1).to(lr_blured.device) * self.noise if random else self.noise noise = torch.randn_like(lr_blured).view(-1, N, C, H_lr, W_lr).mul_(noise_level).view(-1, C, H_lr, W_lr) lr_blured.add_(noise) lr_blured = torch.clamp(lr_blured.round(), 0, 255) return lr_blured.view(B, N, C, H//int(self.scale), W//int(self.scale)), b_kernels