Repository: HarukiYqM/Non-Local-Sparse-Attention
Branch: main
Commit: fb3d42255832
Files: 43
Total size: 117.5 KB

Directory structure:
gitextract_22vbrkxt/
├── README.md
└── src/
    ├── __init__.py
    ├── data/
    │   ├── __init__.py
    │   ├── benchmark.py
    │   ├── common.py
    │   ├── demo.py
    │   ├── div2k.py
    │   ├── div2kjpeg.py
    │   ├── sr291.py
    │   ├── srdata.py
    │   └── video.py
    ├── dataloader.py
    ├── demo.sh
    ├── loss/
    │   ├── __init__.py
    │   ├── __loss__.py
    │   ├── adversarial.py
    │   ├── demo.sh
    │   ├── discriminator.py
    │   ├── hash.py
    │   └── vgg.py
    ├── main.py
    ├── model/
    │   ├── LICENSE
    │   ├── README.md
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── common.py
    │   ├── ddbpn.py
    │   ├── edsr.py
    │   ├── mdsr.py
    │   ├── mssr.py
    │   ├── nlsn.py
    │   ├── rcan.py
    │   ├── rdn.py
    │   ├── utils/
    │   │   ├── __init__.py
    │   │   └── tools.py
    │   └── vdsr.py
    ├── option.py
    ├── template.py
    ├── trainer.py
    ├── utility.py
    ├── utils/
    │   ├── __init__.py
    │   └── tools.py
    └── videotester.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# Image Super-Resolution with Non-Local Sparse Attention
This repository is for NLSN introduced in the following paper

"Image Super-Resolution with Non-Local Sparse Attention", CVPR 2021, [[Link]](https://openaccess.thecvf.com/content/CVPR2021/papers/Mei_Image_Super-Resolution_With_Non-Local_Sparse_Attention_CVPR_2021_paper.pdf)

The code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) and tested in an Ubuntu 18.04 environment (Python 3.6, PyTorch >= 1.1.0) with V100 GPUs.

## Contents
1. [Introduction](#introduction)
2. [Train](#train)
3. [Test](#test)
4. [Citation](#citation)
5. [Acknowledgements](#acknowledgements)

## Introduction
Both the Non-Local (NL) operation and sparse representation are crucial for Single Image Super-Resolution (SISR). In this paper, we investigate their combination and propose a novel Non-Local Sparse Attention (NLSA) with a dynamic sparse attention pattern. NLSA is designed to retain the long-range modeling capability of the NL operation while enjoying the robustness and high efficiency of sparse representation. Specifically, NLSA rectifies non-local attention with spherical locality-sensitive hashing (LSH), which partitions the input space into hash buckets of related features. For every query signal, NLSA assigns it a bucket and computes attention only within that bucket. The resulting sparse attention prevents the model from attending to locations that are noisy and less informative, while reducing the computational cost from quadratic to asymptotically linear with respect to the spatial size. Extensive experiments validate the effectiveness and efficiency of NLSA. With a few non-local sparse attention modules, our architecture, called the Non-Local Sparse Network (NLSN), reaches state-of-the-art performance for SISR quantitatively and qualitatively.

![Non-Local Sparse Attention](/Figs/Attention.png)
Non-Local Sparse Attention.

![NLSN](/Figs/NLSN.png)
Non-Local Sparse Network.
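The bucket assignment is the only non-standard ingredient above, so a minimal sketch may help before the training commands. The helper `lsh_buckets` below is illustrative only (it is not part of the repository) and mirrors `NonLocalSparseAttention.LSH` in `src/model/attention.py`: features are multiplied by a random rotation, and the argmax over the concatenated `[Rx, -Rx]` scores picks one of `hash_buckets` directions on the unit sphere, so features pointing in similar directions land in the same bucket.

```python
import torch

def lsh_buckets(x, hash_buckets, n_hashes=4):
    """Assign every pixel feature to a hash bucket, once per hash round."""
    N, L, C = x.shape                                  # x: [N, H*W, C]
    # one random rotation, shared across the batch
    R = torch.randn(1, C, n_hashes, hash_buckets // 2, device=x.device)
    R = R.expand(N, -1, -1, -1)
    rotated = torch.einsum('blc,bchi->bhli', x, R)     # [N, n_hashes, L, buckets//2]
    rotated = torch.cat([rotated, -rotated], dim=-1)   # antipodal directions
    codes = torch.argmax(rotated, dim=-1)              # nearest direction = bucket id
    # offset each hash round so rounds never share bucket ids
    offsets = torch.arange(n_hashes, device=x.device).view(1, -1, 1) * hash_buckets
    return (codes + offsets).reshape(N, -1)            # [N, n_hashes*L]

x = torch.randn(2, 48 * 48, 64)                        # two 48x48 feature maps, 64 channels
print(lsh_buckets(x, hash_buckets=32).shape)           # torch.Size([2, 9216])
```

Attention is then computed only among the members of each bucket, which is what turns the quadratic NL cost into a roughly linear one.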
## Train
### Prepare training data

1. Download DIV2K training data (800 training + 100 validation images) from the [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar).

2. Specify '--dir_data' based on the HR and LR image paths. For more information, please refer to [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch).

### Begin to train

1. (optional) Download pretrained models for our paper. Pre-trained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zz2a1ih3euzuH3HvWDN-uSki3USym9Cq?usp=sharing).

2. cd to 'src' and run the following script to train models. **Example commands are in the file 'demo.sh'.**

```bash
# Example X2 SR
python main.py --dir_data ../../ --n_GPUs 4 --rgb_range 1 --chunk_size 144 --n_hashes 4 --save_models --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model NLSN --scale 2 --patch_size 96 --save NLSN_x2 --data_train DIV2K
```

## Test
### Quick start

1. Download benchmark datasets from [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/benchmark.tar).

2. (optional) Download pretrained models for our paper. All the models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zz2a1ih3euzuH3HvWDN-uSki3USym9Cq?usp=sharing).

3. cd to 'src' and run the following scripts. **Example commands are in the file 'demo.sh'.**

```bash
# No self-ensemble: NLSN
# Example X2 SR
python main.py --dir_data ../../ --model NLSN --chunk_size 144 --data_test Set5+Set14+B100+Urban100 --n_hashes 4 --chop --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1 --pre_train model_x2.pt --test_only
```

## Citation
If you find the code helpful in your research or work, please cite the following papers.
```
@InProceedings{Mei_2021_CVPR,
    author    = {Mei, Yiqun and Fan, Yuchen and Zhou, Yuqian},
    title     = {Image Super-Resolution With Non-Local Sparse Attention},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2021},
    pages     = {3517-3526}
}
@InProceedings{Lim_2017_CVPR_Workshops,
    author    = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
    title     = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
    booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
    month     = {July},
    year      = {2017}
}
```
## Acknowledgements
This code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) and [reformer-pytorch](https://github.com/lucidrains/reformer-pytorch). We thank the authors for sharing their codes.

================================================
FILE: src/__init__.py
================================================

================================================
FILE: src/data/__init__.py
================================================
from importlib import import_module
#from dataloader import MSDataLoader
from torch.utils.data import dataloader
from torch.utils.data import ConcatDataset

# This is a simple wrapper class for ConcatDataset
class MyConcatDataset(ConcatDataset):
    def __init__(self, datasets):
        super(MyConcatDataset, self).__init__(datasets)
        self.train = datasets[0].train

    def set_scale(self, idx_scale):
        for d in self.datasets:
            if hasattr(d, 'set_scale'): d.set_scale(idx_scale)

class Data:
    def __init__(self, args):
        self.loader_train = None
        if not args.test_only:
            datasets = []
            for d in args.data_train:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data.'
+ module_name.lower()) datasets.append(getattr(m, module_name)(args, name=d)) self.loader_train = dataloader.DataLoader( MyConcatDataset(datasets), batch_size=args.batch_size, shuffle=True, pin_memory=not args.cpu, num_workers=args.n_threads, ) self.loader_test = [] for d in args.data_test: if d in ['Set5', 'Set14', 'B100', 'Urban100']: m = import_module('data.benchmark') testset = getattr(m, 'Benchmark')(args, train=False, name=d) else: module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' m = import_module('data.' + module_name.lower()) testset = getattr(m, module_name)(args, train=False, name=d) self.loader_test.append( dataloader.DataLoader( testset, batch_size=1, shuffle=False, pin_memory=not args.cpu, num_workers=args.n_threads, ) ) ================================================ FILE: src/data/benchmark.py ================================================ import os from data import common from data import srdata import numpy as np import torch import torch.utils.data as data class Benchmark(srdata.SRData): def __init__(self, args, name='', train=True, benchmark=True): super(Benchmark, self).__init__( args, name=name, train=train, benchmark=True ) def _set_filesystem(self, dir_data): self.apath = os.path.join(dir_data, 'benchmark', self.name) self.dir_hr = os.path.join(self.apath, 'HR') if self.input_large: self.dir_lr = os.path.join(self.apath, 'LR_bicubicL') else: self.dir_lr = os.path.join(self.apath, 'LR_bicubic') self.ext = ('', '.png') ================================================ FILE: src/data/common.py ================================================ import random import numpy as np import skimage.color as sc import torch def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): ih, iw = args[0].shape[:2] if not input_large: p = scale if multi else 1 tp = p * patch_size ip = tp // scale else: tp = patch_size ip = patch_size ix = random.randrange(0, iw - ip + 1) iy = random.randrange(0, ih - ip + 1) if not input_large: tx, ty = scale * ix, scale * iy else: tx, ty = ix, iy ret = [ args[0][iy:iy + ip, ix:ix + ip, :], *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] ] return ret def set_channel(*args, n_channels=3): def _set_channel(img): if img.ndim == 2: img = np.expand_dims(img, axis=2) c = img.shape[2] if n_channels == 1 and c == 3: img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) elif n_channels == 3 and c == 1: img = np.concatenate([img] * n_channels, 2) return img return [_set_channel(a) for a in args] def np2Tensor(*args, rgb_range=255): def _np2Tensor(img): np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) tensor = torch.from_numpy(np_transpose).float() tensor.mul_(rgb_range / 255) return tensor return [_np2Tensor(a) for a in args] def augment(*args, hflip=True, rot=True): hflip = hflip and random.random() < 0.5 vflip = rot and random.random() < 0.5 rot90 = rot and random.random() < 0.5 def _augment(img): if hflip: img = img[:, ::-1, :] if vflip: img = img[::-1, :, :] if rot90: img = img.transpose(1, 0, 2) return img return [_augment(a) for a in args] ================================================ FILE: src/data/demo.py ================================================ import os from data import common import numpy as np import imageio import torch import torch.utils.data as data class Demo(data.Dataset): def __init__(self, args, name='Demo', train=False, benchmark=False): self.args = args self.name = name self.scale = args.scale self.idx_scale = 0 self.train = False self.benchmark = benchmark self.filelist = [] for f in 
os.listdir(args.dir_demo): if f.find('.png') >= 0 or f.find('.jp') >= 0: self.filelist.append(os.path.join(args.dir_demo, f)) self.filelist.sort() def __getitem__(self, idx): filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0] lr = imageio.imread(self.filelist[idx]) lr, = common.set_channel(lr, n_channels=self.args.n_colors) lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range) return lr_t, -1, filename def __len__(self): return len(self.filelist) def set_scale(self, idx_scale): self.idx_scale = idx_scale ================================================ FILE: src/data/div2k.py ================================================ import os from data import srdata class DIV2K(srdata.SRData): def __init__(self, args, name='DIV2K', train=True, benchmark=False): data_range = [r.split('-') for r in args.data_range.split('/')] if train: data_range = data_range[0] else: if args.test_only and len(data_range) == 1: data_range = data_range[0] else: data_range = data_range[1] self.begin, self.end = list(map(lambda x: int(x), data_range)) super(DIV2K, self).__init__( args, name=name, train=train, benchmark=benchmark ) def _scan(self): names_hr, names_lr = super(DIV2K, self)._scan() names_hr = names_hr[self.begin - 1:self.end] names_lr = [n[self.begin - 1:self.end] for n in names_lr] return names_hr, names_lr def _set_filesystem(self, dir_data): super(DIV2K, self)._set_filesystem(dir_data) self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') if self.input_large: self.dir_lr += 'L' ================================================ FILE: src/data/div2kjpeg.py ================================================ import os from data import srdata from data import div2k class DIV2KJPEG(div2k.DIV2K): def __init__(self, args, name='', train=True, benchmark=False): self.q_factor = int(name.replace('DIV2K-Q', '')) super(DIV2KJPEG, self).__init__( args, name=name, train=train, benchmark=benchmark ) def _set_filesystem(self, dir_data): self.apath = os.path.join(dir_data, 'DIV2K') self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') self.dir_lr = os.path.join( self.apath, 'DIV2K_Q{}'.format(self.q_factor) ) if self.input_large: self.dir_lr += 'L' self.ext = ('.png', '.jpg') ================================================ FILE: src/data/sr291.py ================================================ from data import srdata class SR291(srdata.SRData): def __init__(self, args, name='SR291', train=True, benchmark=False): super(SR291, self).__init__(args, name=name) ================================================ FILE: src/data/srdata.py ================================================ import os import glob import random import pickle from data import common import numpy as np import imageio import torch import torch.utils.data as data class SRData(data.Dataset): def __init__(self, args, name='', train=True, benchmark=False): self.args = args self.name = name self.train = train self.split = 'train' if train else 'test' self.do_eval = True self.benchmark = benchmark self.input_large = (args.model == 'VDSR') self.scale = args.scale self.idx_scale = 0 self._set_filesystem(args.dir_data) if args.ext.find('img') < 0: path_bin = os.path.join(self.apath, 'bin') os.makedirs(path_bin, exist_ok=True) list_hr, list_lr = self._scan() if args.ext.find('img') >= 0 or benchmark: self.images_hr, self.images_lr = list_hr, list_lr elif args.ext.find('sep') >= 0: os.makedirs( self.dir_hr.replace(self.apath, path_bin), exist_ok=True ) for s in self.scale: 
os.makedirs( os.path.join( self.dir_lr.replace(self.apath, path_bin), 'X{}'.format(s) ), exist_ok=True ) self.images_hr, self.images_lr = [], [[] for _ in self.scale] for h in list_hr: b = h.replace(self.apath, path_bin) b = b.replace(self.ext[0], '.pt') self.images_hr.append(b) self._check_and_load(args.ext, h, b, verbose=True) for i, ll in enumerate(list_lr): for l in ll: b = l.replace(self.apath, path_bin) b = b.replace(self.ext[1], '.pt') self.images_lr[i].append(b) self._check_and_load(args.ext, l, b, verbose=True) if train: n_patches = args.batch_size * args.test_every n_images = len(args.data_train) * len(self.images_hr) if n_images == 0: self.repeat = 0 else: self.repeat = max(n_patches // n_images, 1) # Below functions as used to prepare images def _scan(self): names_hr = sorted( glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) ) names_lr = [[] for _ in self.scale] for f in names_hr: filename, _ = os.path.splitext(os.path.basename(f)) for si, s in enumerate(self.scale): names_lr[si].append(os.path.join( self.dir_lr, 'X{}/{}x{}{}'.format( s, filename, s, self.ext[1] ) )) return names_hr, names_lr def _set_filesystem(self, dir_data): self.apath = os.path.join(dir_data, self.name) self.dir_hr = os.path.join(self.apath, 'HR') self.dir_lr = os.path.join(self.apath, 'LR_bicubic') if self.input_large: self.dir_lr += 'L' self.ext = ('.png', '.png') def _check_and_load(self, ext, img, f, verbose=True): if not os.path.isfile(f) or ext.find('reset') >= 0: if verbose: print('Making a binary: {}'.format(f)) with open(f, 'wb') as _f: pickle.dump(imageio.imread(img), _f) def __getitem__(self, idx): lr, hr, filename = self._load_file(idx) pair = self.get_patch(lr, hr) pair = common.set_channel(*pair, n_channels=self.args.n_colors) pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) return pair_t[0], pair_t[1], filename def __len__(self): if self.train: return len(self.images_hr) * self.repeat else: return len(self.images_hr) def _get_index(self, idx): if self.train: return idx % len(self.images_hr) else: return idx def _load_file(self, idx): idx = self._get_index(idx) f_hr = self.images_hr[idx] f_lr = self.images_lr[self.idx_scale][idx] filename, _ = os.path.splitext(os.path.basename(f_hr)) if self.args.ext == 'img' or self.benchmark: hr = imageio.imread(f_hr) lr = imageio.imread(f_lr) elif self.args.ext.find('sep') >= 0: with open(f_hr, 'rb') as _f: hr = pickle.load(_f) with open(f_lr, 'rb') as _f: lr = pickle.load(_f) return lr, hr, filename def get_patch(self, lr, hr): scale = self.scale[self.idx_scale] if self.train: lr, hr = common.get_patch( lr, hr, patch_size=self.args.patch_size, scale=scale, multi=(len(self.scale) > 1), input_large=self.input_large ) if not self.args.no_augment: lr, hr = common.augment(lr, hr) else: ih, iw = lr.shape[:2] hr = hr[0:ih * scale, 0:iw * scale] return lr, hr def set_scale(self, idx_scale): if not self.input_large: self.idx_scale = idx_scale else: self.idx_scale = random.randint(0, len(self.scale) - 1) ================================================ FILE: src/data/video.py ================================================ import os from data import common import cv2 import numpy as np import imageio import torch import torch.utils.data as data class Video(data.Dataset): def __init__(self, args, name='Video', train=False, benchmark=False): self.args = args self.name = name self.scale = args.scale self.idx_scale = 0 self.train = False self.do_eval = False self.benchmark = benchmark self.filename, _ = 
os.path.splitext(os.path.basename(args.dir_demo))
        self.vidcap = cv2.VideoCapture(args.dir_demo)
        self.n_frames = 0
        self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __getitem__(self, idx):
        success, lr = self.vidcap.read()
        if success:
            self.n_frames += 1
            lr, = common.set_channel(lr, n_channels=self.args.n_colors)
            lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
            return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
        else:
            self.vidcap.release()
            return None

    def __len__(self):
        return self.total_frames

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale

================================================
FILE: src/dataloader.py
================================================
import sys  # needed for sys.exc_info() in _ms_loop
import threading
import random

import torch
import torch.multiprocessing as multiprocessing

from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler
from torch.utils.data import RandomSampler
from torch.utils.data import BatchSampler
from torch.utils.data import _utils
from torch.utils.data.dataloader import _DataLoaderIter
from torch.utils.data._utils import collate
from torch.utils.data._utils import signal_handling
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data._utils import ExceptionWrapper
from torch.utils.data._utils import IS_WINDOWS
from torch.utils.data._utils.worker import ManagerWatchdog
from torch._six import queue

def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
    try:
        collate._use_shared_memory = True
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)
        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue

            if r is None:
                assert done_event.is_set()
                return
            elif done_event.is_set():
                continue

            idx, batch_indices = r
            try:
                idx_scale = 0
                if len(scale) > 1 and dataset.train:
                    idx_scale = random.randrange(0, len(scale))
                    dataset.set_scale(idx_scale)

                samples = collate_fn([dataset[i] for i in batch_indices])
                samples.append(idx_scale)
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples

    except KeyboardInterrupt:
        pass

class _MSDataLoaderIter(_DataLoaderIter):
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        self.sample_iter = iter(self.batch_sampler)

        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()

            base_seed = torch.LongTensor(1).random_()[0]

            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        index_queue,
                        self.worker_result_queue,
                        self.done_event,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                w.daemon = True
                w.start()
self.index_queues.append(index_queue) self.workers.append(w) if self.pin_memory: self.data_queue = queue.Queue() pin_memory_thread = threading.Thread( target=_utils.pin_memory._pin_memory_loop, args=( self.worker_result_queue, self.data_queue, torch.cuda.current_device(), self.done_event ) ) pin_memory_thread.daemon = True pin_memory_thread.start() self.pin_memory_thread = pin_memory_thread else: self.data_queue = self.worker_result_queue _utils.signal_handling._set_worker_pids( id(self), tuple(w.pid for w in self.workers) ) _utils.signal_handling._set_SIGCHLD_handler() self.worker_pids_set = True for _ in range(2 * self.num_workers): self._put_indices() class MSDataLoader(DataLoader): def __init__(self, cfg, *args, **kwargs): super(MSDataLoader, self).__init__( *args, **kwargs, num_workers=cfg.n_threads ) self.scale = cfg.scale def __iter__(self): return _MSDataLoaderIter(self) ================================================ FILE: src/demo.sh ================================================ #!/bin/bash #Train x2 python main.py --dir_data ../../ --n_GPUs 4 --rgb_range 1 --chunk_size 144 --n_hashes 4 --save_models --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model NLSN --scale 2 --patch_size 96 --save NLSN_x2 --data_train DIV2K #Test x2 python main.py --dir_data ../../ --model NLSN --chunk_size 144 --data_test Set5+Set14+B100+Urban100 --n_hashes 4 --chop --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1 --pre_train model_x2.pt --test_only ================================================ FILE: src/loss/__init__.py ================================================ import os from importlib import import_module import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import numpy as np import torch import torch.nn as nn import torch.nn.functional as F class Loss(nn.modules.loss._Loss): def __init__(self, args, ckp): super(Loss, self).__init__() print('Preparing loss function:') self.n_GPUs = args.n_GPUs self.loss = [] self.loss_module = nn.ModuleList() for loss in args.loss.split('+'): weight, loss_type = loss.split('*') if loss_type == 'MSE': loss_function = nn.MSELoss() elif loss_type == 'L1': loss_function = nn.L1Loss() elif loss_type.find('VGG') >= 0: module = import_module('loss.vgg') loss_function = getattr(module, 'VGG')( loss_type[3:], rgb_range=args.rgb_range ) elif loss_type.find('GAN') >= 0: module = import_module('loss.adversarial') loss_function = getattr(module, 'Adversarial')( args, loss_type ) self.loss.append({ 'type': loss_type, 'weight': float(weight), 'function': loss_function} ) if loss_type.find('GAN') >= 0: self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) if len(self.loss) > 1: self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) for l in self.loss: if l['function'] is not None: print('{:.3f} * {}'.format(l['weight'], l['type'])) self.loss_module.append(l['function']) self.log = torch.Tensor() device = torch.device('cpu' if args.cpu else 'cuda') self.loss_module.to(device) if args.precision == 'half': self.loss_module.half() if not args.cpu and args.n_GPUs > 1: self.loss_module = nn.DataParallel(self.loss_module,range(args.n_GPUs)) if args.load != '': self.load(ckp.dir, cpu=args.cpu) def forward(self, sr, hr): losses = [] for i, l in enumerate(self.loss): if l['function'] is not None: loss = l['function'](sr, hr) effective_loss = l['weight'] * loss losses.append(effective_loss) 
self.log[-1, i] += effective_loss.item() elif l['type'] == 'DIS': self.log[-1, i] += self.loss[i - 1]['function'].loss loss_sum = sum(losses) if len(self.loss) > 1: self.log[-1, -1] += loss_sum.item() return loss_sum def step(self): for l in self.get_loss_module(): if hasattr(l, 'scheduler'): l.scheduler.step() def start_log(self): self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) def end_log(self, n_batches): self.log[-1].div_(n_batches) def display_loss(self, batch): n_samples = batch + 1 log = [] for l, c in zip(self.loss, self.log[-1]): log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) return ''.join(log) def plot_loss(self, apath, epoch): axis = np.linspace(1, epoch, epoch) for i, l in enumerate(self.loss): label = '{} Loss'.format(l['type']) fig = plt.figure() plt.title(label) plt.plot(axis, self.log[:, i].numpy(), label=label) plt.legend() plt.xlabel('Epochs') plt.ylabel('Loss') plt.grid(True) plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type']))) plt.close(fig) def get_loss_module(self): if self.n_GPUs == 1: return self.loss_module else: return self.loss_module.module def save(self, apath): torch.save(self.state_dict(), os.path.join(apath, 'loss.pt')) torch.save(self.log, os.path.join(apath, 'loss_log.pt')) def load(self, apath, cpu=False): if cpu: kwargs = {'map_location': lambda storage, loc: storage} else: kwargs = {} self.load_state_dict(torch.load( os.path.join(apath, 'loss.pt'), **kwargs )) self.log = torch.load(os.path.join(apath, 'loss_log.pt')) for l in self.get_loss_module(): if hasattr(l, 'scheduler'): for _ in range(len(self.log)): l.scheduler.step() ================================================ FILE: src/loss/__loss__.py ================================================ ================================================ FILE: src/loss/adversarial.py ================================================ import utility from types import SimpleNamespace from model import common from loss import discriminator import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim class Adversarial(nn.Module): def __init__(self, args, gan_type): super(Adversarial, self).__init__() self.gan_type = gan_type self.gan_k = args.gan_k self.dis = discriminator.Discriminator(args) if gan_type == 'WGAN_GP': # see https://arxiv.org/pdf/1704.00028.pdf pp.4 optim_dict = { 'optimizer': 'ADAM', 'betas': (0, 0.9), 'epsilon': 1e-8, 'lr': 1e-5, 'weight_decay': args.weight_decay, 'decay': args.decay, 'gamma': args.gamma } optim_args = SimpleNamespace(**optim_dict) else: optim_args = args self.optimizer = utility.make_optimizer(optim_args, self.dis) def forward(self, fake, real): # updating discriminator... 
self.loss = 0
        fake_detach = fake.detach()     # do not backpropagate through G
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            # d: B x 1 tensor
            d_fake = self.dis(fake_detach)
            d_real = self.dis(real)
            retain_graph = False
            if self.gan_type == 'GAN':
                loss_d = self.bce(d_real, d_fake)
            elif self.gan_type.find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    # per-sample interpolation coefficient of shape (B, 1, 1, 1)
                    epsilon = torch.rand(
                        fake.size(0), 1, 1, 1,
                        device=fake.device, dtype=fake.dtype
                    )
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.dis(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
            elif self.gan_type == 'RGAN':
                better_real = d_real - d_fake.mean(dim=0, keepdim=True)
                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
                loss_d = self.bce(better_real, better_fake)
                retain_graph = True

            # Discriminator update
            self.loss += loss_d.item()
            loss_d.backward(retain_graph=retain_graph)
            self.optimizer.step()

            if self.gan_type == 'WGAN':
                for p in self.dis.parameters():
                    p.data.clamp_(-1, 1)

        self.loss /= self.gan_k

        # updating generator...
        d_fake_bp = self.dis(fake)      # for backpropagation, use fake as it is
        if self.gan_type == 'GAN':
            label_real = torch.ones_like(d_fake_bp)
            loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_bp.mean()
        elif self.gan_type == 'RGAN':
            better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
            loss_g = self.bce(better_fake, better_real)

        # Generator loss
        return loss_g

    def state_dict(self, *args, **kwargs):
        state_discriminator = self.dis.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()

        return dict(**state_discriminator, **state_optimizer)

    def bce(self, real, fake):
        label_real = torch.ones_like(real)
        label_fake = torch.zeros_like(fake)
        bce_real = F.binary_cross_entropy_with_logits(real, label_real)
        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
        bce_loss = bce_real + bce_fake
        return bce_loss

# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py

================================================
FILE: src/loss/demo.sh
================================================

================================================
FILE: src/loss/discriminator.py
================================================
from model import common

import torch.nn as nn

class Discriminator(nn.Module):
    '''
        output is not normalized
    '''
    def __init__(self, args):
        super(Discriminator, self).__init__()

        in_channels = args.n_colors
        out_channels = 64
        depth = 7

        def _block(_in_channels, _out_channels, stride=1):
            return nn.Sequential(
                nn.Conv2d(
                    _in_channels,
                    _out_channels,
                    3,
                    padding=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(_out_channels),
                nn.LeakyReLU(negative_slope=0.2, inplace=True)
            )

        m_features = [_block(in_channels, out_channels)]
        for i in range(depth):
            in_channels = out_channels
            if i % 2 == 1:
                stride = 1
                out_channels *= 2
            else:
                stride = 2
            m_features.append(_block(in_channels, out_channels, stride=stride))

        patch_size = args.patch_size // (2**((depth + 1) // 2))
        m_classifier = [
            nn.Linear(out_channels * patch_size**2, 1024),
nn.LeakyReLU(negative_slope=0.2, inplace=True), nn.Linear(1024, 1) ] self.features = nn.Sequential(*m_features) self.classifier = nn.Sequential(*m_classifier) def forward(self, x): features = self.features(x) output = self.classifier(features.view(features.size(0), -1)) return output ================================================ FILE: src/loss/hash.py ================================================ from model import common import torch import torch.nn as nn import torch.nn.functional as F import torchvision.models as models class HASH(nn.Module): def __init__(self): super(HASH, self).__init__() self.l1 = nn.L1Loss() def forward(self, sr, qk, orders, hr, m=3): #hash loss qk = F.normalize(qk, p=2, dim=1, eps=5e-5) N,C,H,W = qk.shape qk = qk.view(N,C,H*W) qk_t = qk.permute(0,2,1).contiguous() similarity_map = F.relu(torch.matmul(qk_t, qk),inplace=True) #[N,H*W,H*W] orders = orders.unsqueeze(2).expand_as(similarity_map)#[N,H*W,H*W] orders_t = torch.transpose(orders,1,2) dist = torch.pow(orders-orders_t,2) ls = torch.mean(similarity_map*torch.log(torch.exp(dist+m)+1)) ld = torch.mean((1-similarity_map)*torch.log(torch.exp(-dist+m)+1)) loss = 0.005*(ls+ld)+self.l1(sr,hr) return loss ================================================ FILE: src/loss/vgg.py ================================================ from model import common import torch import torch.nn as nn import torch.nn.functional as F import torchvision.models as models class VGG(nn.Module): def __init__(self, conv_index, rgb_range=1): super(VGG, self).__init__() vgg_features = models.vgg19(pretrained=True).features modules = [m for m in vgg_features] if conv_index.find('22') >= 0: self.vgg = nn.Sequential(*modules[:8]) elif conv_index.find('54') >= 0: self.vgg = nn.Sequential(*modules[:35]) vgg_mean = (0.485, 0.456, 0.406) vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) for p in self.parameters(): p.requires_grad = False def forward(self, sr, hr): def _forward(x): x = self.sub_mean(x) x = self.vgg(x) return x vgg_sr = _forward(sr) with torch.no_grad(): vgg_hr = _forward(hr.detach()) loss = F.mse_loss(vgg_sr, vgg_hr) return loss ================================================ FILE: src/main.py ================================================ import torch import utility import data import model import loss from option import args from trainer import Trainer torch.manual_seed(args.seed) checkpoint = utility.checkpoint(args) def main(): global model if args.data_test == ['video']: from videotester import VideoTester model = model.Model(args, checkpoint) print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0)) t = VideoTester(args, model, checkpoint) t.test() else: if checkpoint.ok: loader = data.Data(args) _model = model.Model(args, checkpoint) print('Total params: %.2fM' % (sum(p.numel() for p in _model.parameters())/1000000.0)) _loss = loss.Loss(args, checkpoint) if not args.test_only else None t = Trainer(args, loader, _model, _loss, checkpoint) while not t.terminate(): t.train() t.test() checkpoint.done() if __name__ == '__main__': main() ================================================ FILE: src/model/LICENSE ================================================ MIT License Copyright (c) 2018 Sanghyun Son Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the 
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

================================================
FILE: src/model/README.md
================================================
# EDSR-PyTorch
![](/figs/main.png)

This repository is an official PyTorch implementation of the paper **"Enhanced Deep Residual Networks for Single Image Super-Resolution"** from **CVPRW 2017, 2nd NTIRE**.
You can find the original code and more information [here](https://github.com/LimBee/NTIRE2017).

If you find our work useful in your research or publication, please cite our work:

[1] Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee, **"Enhanced Deep Residual Networks for Single Image Super-Resolution,"** 2nd NTIRE: New Trends in Image Restoration and Enhancement workshop and challenge on image super-resolution in conjunction with **CVPR 2017**. [[PDF](http://openaccess.thecvf.com/content_cvpr_2017_workshops/w12/papers/Lim_Enhanced_Deep_Residual_CVPR_2017_paper.pdf)] [[arXiv](https://arxiv.org/abs/1707.02921)] [[Slide](https://cv.snu.ac.kr/research/EDSR/Presentation_v3(release).pptx)]
```
@InProceedings{Lim_2017_CVPR_Workshops,
    author    = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
    title     = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
    booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
    month     = {July},
    year      = {2017}
}
```
We provide scripts for reproducing all the results from our paper. You can train your own model from scratch, or use a pre-trained model to enlarge your images.

**Differences from the Torch version**
* Codes are much more compact. (Removed all unnecessary parts.)
* Models are smaller. (About half.)
* Slightly better performance.
* Training and evaluation require less memory.
* Python-based.

## Dependencies
* Python 3.6
* PyTorch >= 0.4.0
* numpy
* skimage
* **imageio**
* matplotlib
* tqdm

**Recent updates**
* July 22, 2018
  * Thanks for the recent commits that contain RDN and RCAN. Please see ``code/demo.sh`` to train/test those models.
  * The dataloader is now much more stable than the previous version. Please erase the ``DIV2K/bin`` folder created before this commit, and please avoid using the ``--ext bin`` argument. Our code will automatically pre-decode png images before training. If you do not have enough space (~10GB) on your disk, we recommend ``--ext img`` (but SLOW!).

## Code
Clone this repository into any place you want.
```bash
git clone https://github.com/thstkdgus35/EDSR-PyTorch
cd EDSR-PyTorch
```

## Quick start (Demo)
You can test our super-resolution algorithm with your own images. Place your images in the ``test`` folder. (like ``test/``) We support **png** and **jpeg** files. Run the script in the ``src`` folder.
Before you run the demo, please uncomment the line in ```demo.sh``` that you want to execute.
```bash
cd src       # You are now in */EDSR-PyTorch/src
sh demo.sh
```

You can find the result images in the ```experiment/test/results``` folder.

| Model | Scale | File name (.pt) | Parameters | PSNR** |
| --- | --- | --- | --- | --- |
| **EDSR** | 2 | EDSR_baseline_x2 | 1.37 M | 34.61 dB |
| | | *EDSR_x2 | 40.7 M | 35.03 dB |
| | 3 | EDSR_baseline_x3 | 1.55 M | 30.92 dB |
| | | *EDSR_x3 | 43.7 M | 31.26 dB |
| | 4 | EDSR_baseline_x4 | 1.52 M | 28.95 dB |
| | | *EDSR_x4 | 43.1 M | 29.25 dB |
| **MDSR** | 2 | MDSR_baseline | 3.23 M | 34.63 dB |
| | | *MDSR | 7.95 M | 34.92 dB |
| | 3 | MDSR_baseline | | 30.94 dB |
| | | *MDSR | | 31.22 dB |
| | 4 | MDSR_baseline | | 28.97 dB |
| | | *MDSR | | 29.24 dB |

*Baseline models are in ``experiment/model``. Please download our final models from [here](https://cv.snu.ac.kr/research/EDSR/model_pytorch.tar) (542MB).

**We measured PSNR using DIV2K 0801 ~ 0900, RGB channels, without self-ensemble. (scale + 2) pixels from the image boundary are ignored.

You can evaluate your models with widely-used benchmark datasets: [Set5 - Bevilacqua et al. BMVC 2012](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html), [Set14 - Zeyde et al. LNCS 2010](https://sites.google.com/site/romanzeyde/research-interests), [B100 - Martin et al. ICCV 2001](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/), [Urban100 - Huang et al. CVPR 2015](https://sites.google.com/site/jbhuang0604/publications/struct_sr).

For these datasets, we first convert the result images to YCbCr color space and evaluate PSNR on the Y channel only. You can download the [benchmark datasets](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) (250MB). Set ``--dir_data `` to evaluate the EDSR and MDSR with the benchmarks.

## How to train EDSR and MDSR
We used the [DIV2K](http://www.vision.ee.ethz.ch/%7Etimofter/publications/Agustsson-CVPRW-2017.pdf) dataset to train our model. Please download it from [here](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (7.1GB).

Unpack the tar file to any place you want. Then, change the ```dir_data``` argument in ```src/option.py``` to the place where the DIV2K images are located.

We recommend pre-processing the images before training. This step will decode all **png** files and save them as binaries. Use the ``--ext sep_reset`` argument on your first run. You can skip the decoding part and use the saved binaries with the ``--ext sep`` argument afterwards. If you have enough RAM (>= 32GB), you can use the ``--ext bin`` argument to pack all DIV2K images into one binary file.

You can train EDSR and MDSR by yourself. All scripts are provided in ``src/demo.sh``. Note that EDSR (x3, x4) requires a pre-trained EDSR (x2). You can ignore this constraint by removing the ```--pre_train ``` argument.
```bash
cd src       # You are now in */EDSR-PyTorch/src
sh demo.sh
```
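As an aside on what the ``--ext sep_reset`` / ``--ext sep`` pre-decoding does under the hood: the helper `decode_to_bin` below is illustrative only (it is not part of the repository) and mirrors `SRData._check_and_load` in `src/data/srdata.py`. Each png is decoded once with imageio and cached as a pickled numpy array with a `.pt` extension, so later runs load binaries instead of re-decoding png files.

```python
# Sketch of the png -> binary pre-decoding used by --ext sep_reset / sep,
# modeled on SRData._check_and_load in src/data/srdata.py.
import os
import glob
import pickle

import imageio

def decode_to_bin(src_dir, bin_dir, reset=False):
    os.makedirs(bin_dir, exist_ok=True)
    for png in sorted(glob.glob(os.path.join(src_dir, '*.png'))):
        bin_path = os.path.join(
            bin_dir, os.path.basename(png).replace('.png', '.pt')
        )
        # 'reset' corresponds to the sep_reset behavior: re-decode everything
        if not os.path.isfile(bin_path) or reset:
            print('Making a binary: {}'.format(bin_path))
            with open(bin_path, 'wb') as f:
                pickle.dump(imageio.imread(png), f)

# e.g. decode_to_bin('DIV2K/DIV2K_train_HR', 'DIV2K/bin/DIV2K_train_HR')
```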
**Update log**
* Jan 04, 2018
  * Many parts are re-written. You cannot use previous scripts and models directly.
  * Pre-trained MDSR is temporarily disabled.
  * Training details are included.
* Jan 09, 2018
  * Missing files are included (```src/data/MyImage.py```).
  * Some links are fixed.
* Jan 16, 2018
  * A memory-efficient forward function is implemented.
  * Add the --chop_forward argument to your script to enable it.
  * Basically, this function first splits a large image into small patches. Those patches are merged after super-resolution. I checked this function with 12GB memory and a 4000 x 2000 input image at scale 4. (Therefore, the output will be 16000 x 8000.)
* Feb 21, 2018
  * Fixed a problem when loading a pre-trained multi-GPU model.
  * Added a pre-trained scale-2 baseline model.
  * This code now only saves the best-performing model by default. For MDSR, 'the best' can be ambiguous. Use the --save_models argument to save all the intermediate models.
  * PyTorch 0.3.1 changed its implementation of the DataLoader function. Therefore, I also changed my implementation of MSDataLoader. You can find it on the feature/dataloader branch.
* Feb 23, 2018
  * Now PyTorch 0.3.1 is the default. Use the legacy/0.3.0 branch if you use the old version.
  * With the new ``src/data/DIV2K.py`` code, one can easily create a new data class for super-resolution.
  * New binary data pack. (Please remove the ``DIV2K_decoded`` folder from your dataset if you have one.)
  * With ``--ext bin``, this code will automatically generate and save the binary data pack that corresponds to the previous ``DIV2K_decoded``. (This requires huge RAM (~45GB; swap can be used), so please be careful.)
  * If you cannot make the binary pack, just use the default setting (``--ext img``).
  * Fixed a bug where the PSNR in the log and the PSNR calculated from the saved images did not match.
  * Now saved images have better quality! (PSNR is ~0.1dB higher than the original code.)
  * Added a performance comparison between the Torch7 model and the PyTorch models.
* Mar 5, 2018
  * All baseline models are uploaded.
  * Now supports half-precision at test time. Use ``--precision half`` to enable it. This does not degrade the output images.
* Mar 11, 2018
  * Fixed some typos in the code and script.
  * Now --ext img is the default setting. Although we recommend --ext bin when training, please use --ext img when you use --test_only.
  * The skip_batch operation is implemented. Use the --skip_threshold argument to skip batches that you want to ignore. Although this function is not exactly the same as that of the Torch7 version, it will work as expected.
* Mar 20, 2018
  * Use ``--ext sep_reset`` to pre-decode large png files. Those decoded files will be saved in the same directory as the DIV2K png files. After the first run, you can use ``--ext sep`` to save time.
  * Now supports various benchmark datasets. For example, try ``--data_test Set5`` to test your model on the Set5 images.
  * Changed the behavior of skip_batch.
* Mar 29, 2018
  * We now provide all models from our paper.
  * We also provide the ``MDSR_baseline_jpeg`` model, which suppresses JPEG artifacts in the original low-resolution image. Please use it if you have any trouble.
  * The ``MyImage`` dataset is changed to the ``Demo`` dataset. It also works more efficiently than before.
  * Some codes and scripts are re-written.
* Apr 9, 2018
  * VGG and adversarial losses are implemented based on [SRGAN](http://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf). [WGAN](https://arxiv.org/abs/1701.07875) and [gradient penalty](https://arxiv.org/abs/1704.00028) are also implemented, but they are not tested yet.
  * Many codes are refactored. If there exists a bug, please report it.
  * [D-DBPN](https://arxiv.org/abs/1803.02735) is implemented. The default setting is D-DBPN-L.
* Apr 26, 2018
  * Compatible with PyTorch 0.4.0.
  * Please use the legacy/0.3.1 branch if you are using the old version of PyTorch.
* Minor bug fixes ================================================ FILE: src/model/__init__.py ================================================ import os from importlib import import_module import torch import torch.nn as nn from torch.autograd import Variable class Model(nn.Module): def __init__(self, args, ckp): super(Model, self).__init__() print('Making model...') self.scale = args.scale self.idx_scale = 0 self.self_ensemble = args.self_ensemble self.chop = args.chop self.precision = args.precision self.cpu = args.cpu self.device = torch.device('cpu' if args.cpu else 'cuda') self.n_GPUs = args.n_GPUs self.save_models = args.save_models module = import_module('model.' + args.model.lower()) self.model = module.make_model(args).to(self.device) if args.precision == 'half': self.model.half() if not args.cpu and args.n_GPUs > 1: self.model = nn.DataParallel(self.model, range(args.n_GPUs)) self.load( ckp.dir, pre_train=args.pre_train, resume=args.resume, cpu=args.cpu ) print(self.model, file=ckp.log_file) def forward(self, x, idx_scale): self.idx_scale = idx_scale target = self.get_model() if hasattr(target, 'set_scale'): target.set_scale(idx_scale) if self.self_ensemble and not self.training: if self.chop: forward_function = self.forward_chop else: forward_function = self.model.forward return self.forward_x8(x, forward_function) elif self.chop and not self.training: return self.forward_chop(x) else: return self.model(x) def get_model(self): if self.n_GPUs == 1: return self.model else: return self.model.module def state_dict(self, **kwargs): target = self.get_model() return target.state_dict(**kwargs) def save(self, apath, epoch, is_best=False): target = self.get_model() torch.save( target.state_dict(), os.path.join(apath, 'model_latest.pt') ) if is_best: torch.save( target.state_dict(), os.path.join(apath, 'model_best.pt') ) if self.save_models: torch.save( target.state_dict(), os.path.join(apath, 'model_{}.pt'.format(epoch)) ) def load(self, apath, pre_train='.', resume=-1, cpu=False): if cpu: kwargs = {'map_location': lambda storage, loc: storage} else: kwargs = {} if resume == -1: self.get_model().load_state_dict( torch.load( os.path.join(apath, 'model_latest.pt'), **kwargs ), strict=False ) elif resume == 0: if pre_train != '.': print('Loading model from {}'.format(pre_train)) self.get_model().load_state_dict( torch.load(pre_train, **kwargs), strict=False ) else: self.get_model().load_state_dict( torch.load( os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), **kwargs ), strict=False ) def forward_chop(self, x, shave=10, min_size=120000): scale = self.scale[self.idx_scale] n_GPUs = min(self.n_GPUs, 4) b, c, h, w = x.size() h_half, w_half = h // 2, w // 2 h_size, w_size = h_half + shave, w_half + shave h_size +=4-h_size%4 w_size +=8-w_size%8 lr_list = [ x[:, :, 0:h_size, 0:w_size], x[:, :, 0:h_size, (w - w_size):w], x[:, :, (h - h_size):h, 0:w_size], x[:, :, (h - h_size):h, (w - w_size):w]] if w_size * h_size < min_size: sr_list = [] for i in range(0, 4, n_GPUs): lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) sr_batch = self.model(lr_batch) sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) else: sr_list = [ self.forward_chop(patch, shave=shave, min_size=min_size) \ for patch in lr_list ] h, w = scale * h, scale * w h_half, w_half = scale * h_half, scale * w_half h_size, w_size = scale * h_size, scale * w_size shave *= scale output = x.new(b, c, h, w) output[:, :, 0:h_half, 0:w_half] \ = sr_list[0][:, :, 0:h_half, 0:w_half] output[:, :, 0:h_half, w_half:w] \ = sr_list[1][:, :, 
0:h_half, (w_size - w + w_half):w_size] output[:, :, h_half:h, 0:w_half] \ = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] output[:, :, h_half:h, w_half:w] \ = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] return output def forward_x8(self, x, forward_function): def _transform(v, op): if self.precision != 'single': v = v.float() v2np = v.data.cpu().numpy() if op == 'v': tfnp = v2np[:, :, :, ::-1].copy() elif op == 'h': tfnp = v2np[:, :, ::-1, :].copy() elif op == 't': tfnp = v2np.transpose((0, 1, 3, 2)).copy() ret = torch.Tensor(tfnp).to(self.device) if self.precision == 'half': ret = ret.half() return ret lr_list = [x] for tf in 'v', 'h', 't': lr_list.extend([_transform(t, tf) for t in lr_list]) sr_list = [forward_function(aug) for aug in lr_list] for i in range(len(sr_list)): if i > 3: sr_list[i] = _transform(sr_list[i], 't') if i % 4 > 1: sr_list[i] = _transform(sr_list[i], 'h') if (i % 4) % 2 == 1: sr_list[i] = _transform(sr_list[i], 'v') output_cat = torch.cat(sr_list, dim=0) output = output_cat.mean(dim=0, keepdim=True) return output ================================================ FILE: src/model/attention.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from model import common class NonLocalSparseAttention(nn.Module): def __init__( self, n_hashes=4, channels=64, k_size=3, reduction=4, chunk_size=144, conv=common.default_conv, res_scale=1): super(NonLocalSparseAttention,self).__init__() self.chunk_size = chunk_size self.n_hashes = n_hashes self.reduction = reduction self.res_scale = res_scale self.conv_match = common.BasicBlock(conv, channels, channels//reduction, k_size, bn=False, act=None) self.conv_assembly = common.BasicBlock(conv, channels, channels, 1, bn=False, act=None) def LSH(self, hash_buckets, x): #x: [N,H*W,C] N = x.shape[0] device = x.device #generate random rotation matrix rotations_shape = (1, x.shape[-1], self.n_hashes, hash_buckets//2) #[1,C,n_hashes,hash_buckets//2] random_rotations = torch.randn(rotations_shape, dtype=x.dtype, device=device).expand(N, -1, -1, -1) #[N, C, n_hashes, hash_buckets//2] #locality sensitive hashing rotated_vecs = torch.einsum('btf,bfhi->bhti', x, random_rotations) #[N, n_hashes, H*W, hash_buckets//2] rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1) #[N, n_hashes, H*W, hash_buckets] #get hash codes hash_codes = torch.argmax(rotated_vecs, dim=-1) #[N,n_hashes,H*W] #add offsets to avoid hash codes overlapping between hash rounds offsets = torch.arange(self.n_hashes, device=device) offsets = torch.reshape(offsets * hash_buckets, (1, -1, 1)) hash_codes = torch.reshape(hash_codes + offsets, (N, -1,)) #[N,n_hashes*H*W] return hash_codes def add_adjacent_buckets(self, x): x_extra_back = torch.cat([x[:,:,-1:, ...], x[:,:,:-1, ...]], dim=2) x_extra_forward = torch.cat([x[:,:,1:, ...], x[:,:,:1,...]], dim=2) return torch.cat([x, x_extra_back,x_extra_forward], dim=3) def forward(self, input): N,_,H,W = input.shape x_embed = self.conv_match(input).view(N,-1,H*W).contiguous().permute(0,2,1) y_embed = self.conv_assembly(input).view(N,-1,H*W).contiguous().permute(0,2,1) L,C = x_embed.shape[-2:] #number of hash buckets/hash bits hash_buckets = min(L//self.chunk_size + (L//self.chunk_size)%2, 128) #get assigned hash codes/bucket number hash_codes = self.LSH(hash_buckets, x_embed) #[N,n_hashes*H*W] hash_codes = hash_codes.detach() #group elements with same hash code by sorting _, indices = hash_codes.sort(dim=-1) 
#[N,n_hashes*H*W]
        _, undo_sort = indices.sort(dim=-1) #undo_sort to recover the original order
        mod_indices = (indices % L) #now ranges in [0, H*W)
        x_embed_sorted = common.batched_index_select(x_embed, mod_indices) #[N,n_hashes*H*W,C]
        y_embed_sorted = common.batched_index_select(y_embed, mod_indices) #[N,n_hashes*H*W,C]

        #pad the embedding if it cannot be divided by chunk_size
        padding = self.chunk_size - L%self.chunk_size if L%self.chunk_size!=0 else 0
        x_att_buckets = torch.reshape(x_embed_sorted, (N, self.n_hashes, -1, C)) #[N, n_hashes, H*W, C]
        y_att_buckets = torch.reshape(y_embed_sorted, (N, self.n_hashes, -1, C*self.reduction))
        if padding:
            pad_x = x_att_buckets[:,:,-padding:,:].clone()
            pad_y = y_att_buckets[:,:,-padding:,:].clone()
            x_att_buckets = torch.cat([x_att_buckets, pad_x], dim=2)
            y_att_buckets = torch.cat([y_att_buckets, pad_y], dim=2)

        x_att_buckets = torch.reshape(x_att_buckets, (N, self.n_hashes, -1, self.chunk_size, C)) #[N, n_hashes, num_chunks, chunk_size, C]
        y_att_buckets = torch.reshape(y_att_buckets, (N, self.n_hashes, -1, self.chunk_size, C*self.reduction))

        x_match = F.normalize(x_att_buckets, p=2, dim=-1, eps=5e-5)

        #allow attending to adjacent buckets
        x_match = self.add_adjacent_buckets(x_match)
        y_att_buckets = self.add_adjacent_buckets(y_att_buckets)

        #unnormalized attention score
        raw_score = torch.einsum('bhkie,bhkje->bhkij', x_att_buckets, x_match) #[N, n_hashes, num_chunks, chunk_size, chunk_size*3]

        #softmax
        bucket_score = torch.logsumexp(raw_score, dim=-1, keepdim=True)
        score = torch.exp(raw_score - bucket_score) #(after softmax)
        bucket_score = torch.reshape(bucket_score, [N, self.n_hashes, -1])

        #attention
        ret = torch.einsum('bukij,bukje->bukie', score, y_att_buckets) #[N, n_hashes, num_chunks, chunk_size, C]
        ret = torch.reshape(ret, (N, self.n_hashes, -1, C*self.reduction))

        #if padded, then remove the extra elements
        if padding:
            ret = ret[:,:,:-padding,:].clone()
            bucket_score = bucket_score[:,:,:-padding].clone()

        #recover the original order
        ret = torch.reshape(ret, (N, -1, C*self.reduction)) #[N, n_hashes*H*W, C]
        bucket_score = torch.reshape(bucket_score, (N, -1,)) #[N, n_hashes*H*W]
        ret = common.batched_index_select(ret, undo_sort) #[N, n_hashes*H*W, C]
        bucket_score = bucket_score.gather(1, undo_sort) #[N, n_hashes*H*W]

        #weighted sum over the multi-round attention
        ret = torch.reshape(ret, (N, self.n_hashes, L, C*self.reduction)) #[N, n_hashes, H*W, C]
        bucket_score = torch.reshape(bucket_score, (N, self.n_hashes, L, 1))
        probs = nn.functional.softmax(bucket_score, dim=1)
        ret = torch.sum(ret * probs, dim=1)

        ret = ret.permute(0,2,1).view(N,-1,H,W).contiguous()*self.res_scale + input
        return ret

class NonLocalAttention(nn.Module):
    def __init__(self, channel=128, reduction=2, ksize=1, scale=3, stride=1, softmax_scale=10, average=True, res_scale=1, conv=common.default_conv):
        super(NonLocalAttention, self).__init__()
        self.res_scale = res_scale
        self.conv_match1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_match2 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())

    def forward(self, input):
        x_embed_1 = self.conv_match1(input)
        x_embed_2 = self.conv_match2(input)
        x_assembly = self.conv_assembly(input)

        N,C,H,W = x_embed_1.shape
        x_embed_1 = x_embed_1.permute(0,2,3,1).view((N,H*W,C))
        x_embed_2 = x_embed_2.view(N,C,H*W)
        score = torch.matmul(x_embed_1, x_embed_2)
        score = F.softmax(score, dim=2)
        x_assembly = x_assembly.view(N,-1,H*W).permute(0,2,1)
        x_final = torch.matmul(score, x_assembly)
        return x_final.permute(0,2,1).view(N,-1,H,W) + self.res_scale*input
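A quick way to sanity-check `NonLocalSparseAttention` is a shape test: the module is residual, so the output has the same `[N, C, H, W]` shape as the input feature map. The snippet below is a usage sketch, not part of the repository; run it from `src/` so that the `model` package is importable. The hyperparameter values mirror the constructor defaults and the `--chunk_size 144 --n_hashes 4` settings from `demo.sh`.

```python
# Shape check for NonLocalSparseAttention (run from src/ so that
# `model` and `model.common` resolve). The module is untrained here;
# this only verifies the residual input/output contract.
import torch
from model.attention import NonLocalSparseAttention

nlsa = NonLocalSparseAttention(
    channels=64,     # input feature channels
    chunk_size=144,  # bucket/chunk size used for attention
    n_hashes=4,      # number of LSH rounds that are score-averaged
    reduction=4,     # channel reduction for the matching embedding
)
x = torch.randn(1, 64, 48, 48)  # e.g. a 48x48 feature patch
with torch.no_grad():
    y = nlsa(x)
print(y.shape)  # torch.Size([1, 64, 48, 48])
```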
================================================
FILE: src/model/common.py
================================================
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

def batched_index_select(values, indices):
    last_dim = values.shape[-1]
    return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim))

def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), stride=stride, bias=bias)

class MeanShift(nn.Conv2d):
    def __init__(
        self, rgb_range,
        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):

        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False

class BasicBlock(nn.Sequential):
    def __init__(
        self, conv, in_channels, out_channels, kernel_size,
        stride=1, bias=True, bn=False, act=nn.PReLU()):

        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)

        super(BasicBlock, self).__init__(*m)

class ResBlock(nn.Module):
    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=nn.PReLU(), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res

class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:    # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feats, 4 * n_feats, 3, bias=bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feats))
                if act == 'relu':
                    m.append(nn.ReLU(True))
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))

        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias=bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                m.append(nn.ReLU(True))
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)
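As a quick illustration of the sub-pixel convolution arithmetic in `Upsampler`: each x2 stage expands channels by `r^2 = 4` and `nn.PixelShuffle(2)` rearranges `[N, 4C, H, W]` into `[N, C, 2H, 2W]`; a x4 upsampler simply stacks two such stages (the `(scale & (scale - 1)) == 0` check detects powers of two). The snippet below is a usage sketch, again assuming `src/` as the working directory:

```python
# Two stacked x2 sub-pixel stages: 24 -> 48 -> 96 in spatial size,
# while the channel count returns to n_feats after each PixelShuffle.
import torch
from model import common

up = common.Upsampler(common.default_conv, scale=4, n_feats=64)
x = torch.randn(1, 64, 24, 24)
print(up(x).shape)  # torch.Size([1, 64, 96, 96])
```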
def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), stride=stride, bias=bias)

class MeanShift(nn.Conv2d):
    def __init__(
        self, rgb_range,
        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):

        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False

class BasicBlock(nn.Sequential):
    def __init__(
        self, conv, in_channels, out_channels, kernel_size, stride=1,
        bias=True, bn=False, act=nn.PReLU()):

        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)

        super(BasicBlock, self).__init__(*m)

class ResBlock(nn.Module):
    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=nn.PReLU(), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res

class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:    # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feats, 4 * n_feats, 3, bias=bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feats))
                if act == 'relu':
                    m.append(nn.ReLU(True))
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))
        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias=bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                m.append(nn.ReLU(True))
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)
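# Minimal usage sketch (illustrative, not used by the training code):
# Upsampler builds x2/x3/x4 tails from conv + PixelShuffle stages. A x4
# tail is two x2 stages, each expanding C -> 4C channels and shuffling
# them into 2x spatial size. Sizes below are illustrative.
def _upsampler_demo():
    up4 = Upsampler(default_conv, scale=4, n_feats=8)
    x = torch.randn(1, 8, 12, 12)      # [N, C, H, W]
    y = up4(x)
    assert y.shape == (1, 8, 48, 48)   # channels kept, H and W scaled x4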

================================================
FILE: src/model/ddbpn.py
================================================
# Deep Back-Projection Networks For Super-Resolution
# https://arxiv.org/abs/1803.02735
from model import common

import torch
import torch.nn as nn

def make_model(args, parent=False):
    return DDBPN(args)

def projection_conv(in_channels, out_channels, scale, up=True):
    kernel_size, stride, padding = {
        2: (6, 2, 2),
        4: (8, 4, 2),
        8: (12, 8, 2)
    }[scale]
    if up:
        conv_f = nn.ConvTranspose2d
    else:
        conv_f = nn.Conv2d

    return conv_f(
        in_channels, out_channels, kernel_size,
        stride=stride, padding=padding
    )

class DenseProjection(nn.Module):
    def __init__(self, in_channels, nr, scale, up=True, bottleneck=True):
        super(DenseProjection, self).__init__()
        if bottleneck:
            self.bottleneck = nn.Sequential(*[
                nn.Conv2d(in_channels, nr, 1),
                nn.PReLU(nr)
            ])
            inter_channels = nr
        else:
            self.bottleneck = None
            inter_channels = in_channels

        self.conv_1 = nn.Sequential(*[
            projection_conv(inter_channels, nr, scale, up),
            nn.PReLU(nr)
        ])
        self.conv_2 = nn.Sequential(*[
            projection_conv(nr, inter_channels, scale, not up),
            nn.PReLU(inter_channels)
        ])
        self.conv_3 = nn.Sequential(*[
            projection_conv(inter_channels, nr, scale, up),
            nn.PReLU(nr)
        ])

    def forward(self, x):
        if self.bottleneck is not None:
            x = self.bottleneck(x)

        a_0 = self.conv_1(x)
        b_0 = self.conv_2(a_0)
        e = b_0.sub(x)
        a_1 = self.conv_3(e)

        out = a_0.add(a_1)

        return out

class DDBPN(nn.Module):
    def __init__(self, args):
        super(DDBPN, self).__init__()
        scale = args.scale[0]

        n0 = 128
        nr = 32
        self.depth = 6

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
        initial = [
            nn.Conv2d(args.n_colors, n0, 3, padding=1),
            nn.PReLU(n0),
            nn.Conv2d(n0, nr, 1),
            nn.PReLU(nr)
        ]
        self.initial = nn.Sequential(*initial)

        self.upmodules = nn.ModuleList()
        self.downmodules = nn.ModuleList()
        channels = nr
        for i in range(self.depth):
            self.upmodules.append(
                DenseProjection(channels, nr, scale, True, i > 1)
            )
            if i != 0:
                channels += nr

        channels = nr
        for i in range(self.depth - 1):
            self.downmodules.append(
                DenseProjection(channels, nr, scale, False, i != 0)
            )
            channels += nr

        reconstruction = [
            nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1)
        ]
        self.reconstruction = nn.Sequential(*reconstruction)

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.initial(x)

        h_list = []
        l_list = []
        for i in range(self.depth - 1):
            if i == 0:
                l = x
            else:
                l = torch.cat(l_list, dim=1)
            h_list.append(self.upmodules[i](l))
            l_list.append(self.downmodules[i](torch.cat(h_list, dim=1)))

        h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1)))
        out = self.reconstruction(torch.cat(h_list, dim=1))
        out = self.add_mean(out)

        return out
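# Minimal usage sketch (illustrative, not used by the training code):
# projection_conv pairs a strided conv (down) with a transposed conv (up)
# whose kernel/stride/padding are chosen so the two exactly invert each
# other's spatial size. Sizes below are illustrative.
def _projection_conv_demo():
    up = projection_conv(8, 8, scale=4, up=True)     # ConvTranspose2d, x4
    down = projection_conv(8, 8, scale=4, up=False)  # Conv2d, /4
    x = torch.randn(1, 8, 6, 6)
    assert up(x).shape == (1, 8, 24, 24)             # (6-1)*4 - 2*2 + 8 = 24
    assert down(up(x)).shape == x.shape              # back to 6x6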

================================================
FILE: src/model/edsr.py
================================================
from model import common
from model import attention

import torch.nn as nn

def make_model(args, parent=False):
    if args.dilation:
        from model import dilated
        return EDSR(args, dilated.dilated_conv)
    else:
        return EDSR(args)

class EDSR(nn.Module):
    def __init__(self, args, conv=common.default_conv):
        super(EDSR, self).__init__()
        n_resblock = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        scale = args.scale[0]
        act = nn.ReLU(True)

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
        #self.msa = attention.PyramidAttention(channel=256, reduction=8, res_scale=args.res_scale)

        # define head module
        m_head = [conv(args.n_colors, n_feats, kernel_size)]

        # define body module
        m_body = [
            common.ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
            ) for _ in range(n_resblock // 2)
        ]
        #m_body.append(self.msa)
        for _ in range(n_resblock // 2):
            m_body.append(common.ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
            ))
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [
            common.Upsampler(conv, scale, n_feats, act=False),
            nn.Conv2d(
                n_feats, args.n_colors, kernel_size,
                padding=(kernel_size // 2)
            )
        ]

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)
        x = self.add_mean(x)

        return x

    def load_state_dict(self, state_dict, strict=True):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') == -1:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))


================================================
FILE: src/model/mdsr.py
================================================
from model import common

import torch.nn as nn

def make_model(args, parent=False):
    return MDSR(args)

class MDSR(nn.Module):
    def __init__(self, args, conv=common.default_conv):
        super(MDSR, self).__init__()
        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        self.scale_idx = 0
        act = nn.ReLU(True)

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)

        m_head = [conv(args.n_colors, n_feats, kernel_size)]

        self.pre_process = nn.ModuleList([
            nn.Sequential(
                common.ResBlock(conv, n_feats, 5, act=act),
                common.ResBlock(conv, n_feats, 5, act=act)
            ) for _ in args.scale
        ])

        m_body = [
            common.ResBlock(
                conv, n_feats, kernel_size, act=act
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        self.upsample = nn.ModuleList([
            common.Upsampler(
                conv, s, n_feats, act=False
            ) for s in args.scale
        ])

        m_tail = [conv(n_feats, args.n_colors, kernel_size)]

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)
        x = self.pre_process[self.scale_idx](x)

        res = self.body(x)
        res += x

        x = self.upsample[self.scale_idx](res)
        x = self.tail(x)
        x = self.add_mean(x)

        return x

    def set_scale(self, scale_idx):
        self.scale_idx = scale_idx

================================================
FILE: src/model/mssr.py
================================================
from model import common

import torch.nn as nn
import torch
from model.attention import ContextualAttention, NonLocalAttention

def make_model(args, parent=False):
    return MSSR(args)

class MultisourceProjection(nn.Module):
    def __init__(self, in_channel, kernel_size=3, conv=common.default_conv):
        super(MultisourceProjection, self).__init__()
        self.up_attention = ContextualAttention(scale=2)
        self.down_attention = NonLocalAttention()
        self.upsample = nn.Sequential(*[nn.ConvTranspose2d(in_channel, in_channel, 6, stride=2, padding=2), nn.PReLU()])
        self.encoder = common.ResBlock(conv, in_channel, kernel_size, act=nn.PReLU(), res_scale=1)

    def forward(self, x):
        down_map = self.upsample(self.down_attention(x))
        up_map = self.up_attention(x)

        err = self.encoder(up_map - down_map)
        final_map = down_map + err

        return final_map

class RecurrentProjection(nn.Module):
    def __init__(self, in_channel, kernel_size=3, conv=common.default_conv):
        super(RecurrentProjection, self).__init__()
        self.multi_source_projection_1 = MultisourceProjection(in_channel, kernel_size=kernel_size, conv=conv)
        self.multi_source_projection_2 = MultisourceProjection(in_channel, kernel_size=kernel_size, conv=conv)
        self.down_sample_1 = nn.Sequential(*[nn.Conv2d(in_channel, in_channel, 6, stride=2, padding=2), nn.PReLU()])
        #self.down_sample_2 = nn.Sequential(*[nn.Conv2d(in_channel, in_channel, 6, stride=2, padding=2), nn.PReLU()])
        self.down_sample_3 = nn.Sequential(*[nn.Conv2d(in_channel, in_channel, 8, stride=4, padding=2), nn.PReLU()])
        self.down_sample_4 = nn.Sequential(*[nn.Conv2d(in_channel, in_channel, 8, stride=4, padding=2), nn.PReLU()])
        self.error_encode_1 = nn.Sequential(*[nn.ConvTranspose2d(in_channel, in_channel, 6, stride=2, padding=2), nn.PReLU()])
        self.error_encode_2 = nn.Sequential(*[nn.ConvTranspose2d(in_channel, in_channel, 8, stride=4, padding=2), nn.PReLU()])
        self.post_conv = common.BasicBlock(conv, in_channel, in_channel, kernel_size, stride=1, bias=True, act=nn.PReLU())

    def forward(self, x):
        x_up = self.multi_source_projection_1(x)
        x_down = self.down_sample_1(x_up)
        error_up = self.error_encode_1(x - x_down)
        h_estimate_1 = x_up + error_up

        x_up_2 = self.multi_source_projection_2(h_estimate_1)
        x_down_2 = self.down_sample_3(x_up_2)
        error_up_2 = self.error_encode_2(x - x_down_2)
        h_estimate_2 = x_up_2 + error_up_2
        x_final = self.post_conv(self.down_sample_4(h_estimate_2))

        return x_final, h_estimate_2

class MSSR(nn.Module):
    def __init__(self, args, conv=common.default_conv):
        super(MSSR, self).__init__()

        #n_convblock = args.n_convblocks
        n_feats = args.n_feats
        self.depth = args.depth
        kernel_size = 3
        scale = args.scale[0]

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)

        # define head module
        m_head = [
            common.BasicBlock(conv, args.n_colors, n_feats, kernel_size, stride=1, bias=True, bn=False, act=nn.PReLU()),
            common.BasicBlock(conv, n_feats, n_feats, kernel_size, stride=1, bias=True, bn=False, act=nn.PReLU())
        ]

        # define multiple reconstruction module
        self.body = RecurrentProjection(n_feats)

        # define tail module
        m_tail = [
            nn.Conv2d(
                n_feats * self.depth, args.n_colors, kernel_size,
                padding=(kernel_size // 2)
            )
        ]

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*m_head)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, input):
        x = self.sub_mean(input)
        x = self.head(x)

        bag = []
        for i in range(self.depth):
            x, h_estimate = self.body(x)
            bag.append(h_estimate)
        h_feature = torch.cat(bag, dim=1)
        h_final = self.tail(h_feature)

        return self.add_mean(h_final)

================================================
FILE: src/model/nlsn.py
================================================
from model import common
from model import attention

import torch.nn as nn

def make_model(args, parent=False):
    if args.dilation:
        from model import dilated
        return NLSN(args, dilated.dilated_conv)
    else:
        return NLSN(args)

class NLSN(nn.Module):
    def __init__(self, args, conv=common.default_conv):
        super(NLSN, self).__init__()
        n_resblock = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        scale = args.scale[0]
        act = nn.ReLU(True)

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)

        m_head = [conv(args.n_colors, n_feats, kernel_size)]

        # define body module: one NLSA block up front, then another
        # after every 8 residual blocks
        m_body = [attention.NonLocalSparseAttention(
            channels=n_feats, chunk_size=args.chunk_size,
            n_hashes=args.n_hashes, reduction=4, res_scale=args.res_scale)]

        for i in range(n_resblock):
            m_body.append(common.ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
            ))
            if (i + 1) % 8 == 0:
                m_body.append(attention.NonLocalSparseAttention(
                    channels=n_feats, chunk_size=args.chunk_size,
                    n_hashes=args.n_hashes, reduction=4,
                    res_scale=args.res_scale))
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [
            common.Upsampler(conv, scale, n_feats, act=False),
            nn.Conv2d(
                n_feats, args.n_colors, kernel_size,
                padding=(kernel_size // 2)
            )
        ]

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)
        x = self.add_mean(x)

        return x

    def load_state_dict(self, state_dict, strict=True):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') == -1:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))
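# Minimal usage sketch (illustrative, not used by the training code): NLSN
# only reads a handful of fields from args, so a bare namespace suffices for
# a CPU smoke test. Field values mirror the README's x2 command, with
# n_resblocks/n_feats shrunk to keep the test light; all values are
# illustrative only.
def _nlsn_smoke_test():
    import argparse
    import torch
    args = argparse.Namespace(
        n_resblocks=8, n_feats=16, scale=[2], rgb_range=1, n_colors=3,
        res_scale=0.1, chunk_size=144, n_hashes=4, dilation=False)
    model = make_model(args)
    sr = model(torch.randn(1, 3, 48, 48))
    assert sr.shape == (1, 3, 96, 96)  # x2 upscaling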

================================================
FILE: src/model/rcan.py
================================================
## ECCV-2018-Image Super-Resolution Using Very Deep Residual Channel Attention Networks
## https://arxiv.org/abs/1807.02758
from model import common

import torch.nn as nn
import torch

def make_model(args, parent=False):
    return RCAN(args)

## Channel Attention (CA) Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        #self.a = torch.nn.Parameter(torch.Tensor([0]))
        #self.a.requires_grad = True
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y

## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    def __init__(
        self, conv, n_feat, kernel_size, reduction,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(RCAB, self).__init__()
        modules_body = []
        for i in range(2):
            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                modules_body.append(nn.BatchNorm2d(n_feat))
            if i == 0:
                modules_body.append(act)
        modules_body.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        #res = self.body(x).mul(self.res_scale)
        res += x
        return res

## Residual Group (RG)
class ResidualGroup(nn.Module):
    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
        super(ResidualGroup, self).__init__()
        modules_body = [
            RCAB(
                conv, n_feat, kernel_size, reduction, bias=True, bn=False,
                act=nn.ReLU(True), res_scale=1) \
            for _ in range(n_resblocks)]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res

## Residual Channel Attention Network (RCAN)
class RCAN(nn.Module):
    def __init__(self, args, conv=common.default_conv):
        super(RCAN, self).__init__()
        self.a = nn.Parameter(torch.Tensor([0]))
        self.a.requires_grad = True
        n_resgroups = args.n_resgroups
        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        reduction = args.reduction
        scale = args.scale[0]
        act = nn.ReLU(True)

        # RGB mean for DIV2K
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)

        # define head module
        modules_head = [conv(args.n_colors, n_feats, kernel_size)]

        # define body module
        modules_body = [
            ResidualGroup(
                conv, n_feats, kernel_size, reduction, act=act,
                res_scale=args.res_scale, n_resblocks=n_resblocks) \
            for _ in range(n_resgroups)]
        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        modules_tail = [
            common.Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, args.n_colors, kernel_size)]

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)
        x = self.add_mean(x)

        return x

    def load_state_dict(self, state_dict, strict=False):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    # skip size-mismatched 'msa'/'a' parameters; note that a
                    # bare name.find('msa') is truthy for misses (-1), so the
                    # comparisons must be explicit
                    if name.find('msa') >= 0 or name.find('a') >= 0:
                        print('Replace pre-trained upsampler to new one...')
                    else:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('msa') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))

================================================
FILE: src/model/rdn.py
================================================
# Residual Dense Network for Image Super-Resolution
# https://arxiv.org/abs/1802.08797
from model import common

import torch
import torch.nn as nn

def make_model(args, parent=False):
    return RDN(args)

class RDB_Conv(nn.Module):
    def __init__(self, inChannels, growRate, kSize=3):
        super(RDB_Conv, self).__init__()
        Cin = inChannels
        G = growRate
        self.conv = nn.Sequential(*[
            nn.Conv2d(Cin, G, kSize, padding=(kSize - 1) // 2, stride=1),
            nn.ReLU()
        ])

    def forward(self, x):
        out = self.conv(x)
        return torch.cat((x, out), 1)

class RDB(nn.Module):
    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
        super(RDB, self).__init__()
        G0 = growRate0
        G = growRate
        C = nConvLayers

        convs = []
        for c in range(C):
            convs.append(RDB_Conv(G0 + c * G, G))
        self.convs = nn.Sequential(*convs)

        # Local Feature Fusion
        self.LFF = nn.Conv2d(G0 + C * G, G0, 1, padding=0, stride=1)

    def forward(self, x):
        return self.LFF(self.convs(x)) + x

class RDN(nn.Module):
    def __init__(self, args):
        super(RDN, self).__init__()
        r = args.scale[0]
        G0 = args.G0
        kSize = args.RDNkSize

        # number of RDB blocks, conv layers, out channels
        self.D, C, G = {
            'A': (20, 6, 32),
            'B': (16, 8, 64),
        }[args.RDNconfig]

        # Shallow feature extraction net
        self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize - 1) // 2, stride=1)
        self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize - 1) // 2, stride=1)

        # Residual dense blocks and dense feature fusion
        self.RDBs = nn.ModuleList()
        for i in range(self.D):
            self.RDBs.append(
                RDB(growRate0=G0, growRate=G, nConvLayers=C)
            )

        # Global Feature Fusion
        self.GFF = nn.Sequential(*[
            nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
            nn.Conv2d(G0, G0, kSize, padding=(kSize - 1) // 2, stride=1)
        ])

        # Up-sampling net
        if r == 2 or r == 3:
            self.UPNet = nn.Sequential(*[
                nn.Conv2d(G0, G * r * r, kSize, padding=(kSize - 1) // 2, stride=1),
                nn.PixelShuffle(r),
                nn.Conv2d(G, args.n_colors, kSize, padding=(kSize - 1) // 2, stride=1)
            ])
        elif r == 4:
            self.UPNet = nn.Sequential(*[
                nn.Conv2d(G0, G * 4, kSize, padding=(kSize - 1) // 2, stride=1),
                nn.PixelShuffle(2),
                nn.Conv2d(G, G * 4, kSize, padding=(kSize - 1) // 2, stride=1),
                nn.PixelShuffle(2),
                nn.Conv2d(G, args.n_colors, kSize, padding=(kSize - 1) // 2, stride=1)
            ])
        else:
            raise ValueError("scale must be 2 or 3 or 4.")

    def forward(self, x):
        f__1 = self.SFENet1(x)
        x = self.SFENet2(f__1)

        RDBs_out = []
        for i in range(self.D):
            x = self.RDBs[i](x)
            RDBs_out.append(x)

        x = self.GFF(torch.cat(RDBs_out, 1))
        x += f__1

        return self.UPNet(x)


================================================
FILE: src/model/utils/__init__.py
================================================


================================================
FILE: src/model/utils/tools.py
================================================
import os
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F

def normalize(x):
    return x.mul_(2).add_(-1)

def same_padding(images, ksizes, strides, rates):
    assert len(images.size()) == 4
    batch_size, channel, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
    padding_rows = max(0, (out_rows - 1) * strides[0] + effective_k_row - rows)
    padding_cols = max(0, (out_cols - 1) * strides[1] + effective_k_col - cols)
    # Pad the input
    padding_top = int(padding_rows / 2.)
    padding_left = int(padding_cols / 2.)
    padding_bottom = padding_rows - padding_top
    padding_right = padding_cols - padding_left
    paddings = (padding_left, padding_right, padding_top, padding_bottom)
    images = torch.nn.ZeroPad2d(paddings)(images)
    return images

def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract patches from images and put them in the C output dimension.
    :param padding:
    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
     each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :return: A Tensor
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']
    batch_size, channel, height, width = images.size()

    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding == 'valid':
        pass
    else:
        raise NotImplementedError('Unsupported padding type: {}. '
                                  'Only "same" or "valid" are supported.'.format(padding))

    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    patches = unfold(images)
    return patches  # [N, C*k*k, L], L is the total number of such blocks

def reduce_mean(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.mean(x, dim=i, keepdim=keepdim)
    return x

def reduce_std(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.std(x, dim=i, keepdim=keepdim)
    return x

def reduce_sum(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x
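# Minimal usage sketch (illustrative, not used by the training code):
# extract_image_patches with 'same' padding keeps one patch per stride step
# over the padded input; patches are flattened into the channel dimension,
# Unfold-style. Sizes below are illustrative.
def _extract_image_patches_demo():
    images = torch.randn(1, 3, 8, 8)
    patches = extract_image_patches(
        images, ksizes=[3, 3], strides=[2, 2], rates=[1, 1], padding='same')
    # out_rows = out_cols = ceil(8 / 2) = 4 -> L = 16 patches of C*k*k = 27
    assert patches.shape == (1, 27, 16)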

================================================
FILE: src/model/vdsr.py
================================================
from model import common

import torch.nn as nn
import torch.nn.init as init

url = {
    'r20f64': ''
}

def make_model(args, parent=False):
    return VDSR(args)

class VDSR(nn.Module):
    def __init__(self, args, conv=common.default_conv):
        super(VDSR, self).__init__()

        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        self.url = url['r{}f{}'.format(n_resblocks, n_feats)]
        self.sub_mean = common.MeanShift(args.rgb_range)
        self.add_mean = common.MeanShift(args.rgb_range, sign=1)

        def basic_block(in_channels, out_channels, act):
            return common.BasicBlock(
                conv, in_channels, out_channels, kernel_size,
                bias=True, bn=False, act=act
            )

        # define body module
        m_body = []
        m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True)))
        for _ in range(n_resblocks - 2):
            m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True)))
        m_body.append(basic_block(n_feats, args.n_colors, None))

        self.body = nn.Sequential(*m_body)

    def forward(self, x):
        x = self.sub_mean(x)
        res = self.body(x)
        res += x
        x = self.add_mean(res)

        return x


================================================
FILE: src/option.py
================================================
import argparse
import template

parser = argparse.ArgumentParser(description='EDSR and MDSR')

parser.add_argument('--debug', action='store_true',
                    help='Enables debug mode')
parser.add_argument('--template', default='.',
                    help='You can set various templates in option.py')

# Hardware specifications
parser.add_argument('--n_threads', type=int, default=18,
                    help='number of threads for data loading')
parser.add_argument('--cpu', action='store_true',
                    help='use cpu only')
parser.add_argument('--n_GPUs', type=int, default=1,
                    help='number of GPUs')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')
parser.add_argument('--local_rank', type=int, default=0)

# Data specifications
parser.add_argument('--dir_data', type=str, default='../../../',
                    help='dataset directory')
parser.add_argument('--dir_demo', type=str, default='../Demo',
                    help='demo image directory')
parser.add_argument('--data_train', type=str, default='DIV2K',
                    help='train dataset name')
parser.add_argument('--data_test', type=str, default='DIV2K',
                    help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-800/801-810',
                    help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
                    help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
                    help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=192,
                    help='output patch size')
parser.add_argument('--rgb_range', type=int, default=255,
                    help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
                    help='number of color channels to use')
parser.add_argument('--chunk_size', type=int, default=144,
                    help='attention bucket size')
parser.add_argument('--n_hashes', type=int, default=4,
                    help='number of hash rounds')
parser.add_argument('--chop', action='store_true',
                    help='enable memory-efficient forward')
parser.add_argument('--no_augment', action='store_true',
                    help='do not use data augmentation')

# Model specifications
parser.add_argument('--model', default='EDSR',
                    help='model name')
parser.add_argument('--act', type=str, default='relu',
                    help='activation function')
parser.add_argument('--pre_train', type=str, default='.',
                    help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.',
                    help='pre-trained model directory')
parser.add_argument('--n_resblocks', type=int, default=20,
                    help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64,
                    help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1,
                    help='residual scaling')
parser.add_argument('--shift_mean', default=True,
                    help='subtract pixel mean from the input')
parser.add_argument('--dilation', action='store_true',
                    help='use dilated convolution')
parser.add_argument('--precision', type=str, default='single',
                    choices=('single', 'half'),
                    help='FP precision for test (single | half)')

# Option for Residual dense network (RDN)
parser.add_argument('--G0', type=int, default=64,
                    help='default number of filters. (Use in RDN)')
parser.add_argument('--RDNkSize', type=int, default=3,
                    help='default kernel size. (Use in RDN)')
parser.add_argument('--RDNconfig', type=str, default='B',
                    help='parameters config of RDN. (Use in RDN)')
parser.add_argument('--depth', type=int, default=12,
                    help='number of residual groups')

# Option for Residual channel attention network (RCAN)
parser.add_argument('--n_resgroups', type=int, default=10,
                    help='number of residual groups')
parser.add_argument('--reduction', type=int, default=16,
                    help='number of feature maps reduction')

# Training specifications
parser.add_argument('--reset', action='store_true',
                    help='reset the training')
parser.add_argument('--test_every', type=int, default=1000,
                    help='do test per every N batches')
parser.add_argument('--epochs', type=int, default=1000,
                    help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=16,
                    help='input batch size for training')
parser.add_argument('--split_batch', type=int, default=1,
                    help='split the batch into smaller chunks')
parser.add_argument('--self_ensemble', action='store_true',
                    help='use self-ensemble method for test')
parser.add_argument('--test_only', action='store_true',
                    help='set this option to test the model')
parser.add_argument('--gan_k', type=int, default=1,
                    help='k value for adversarial loss')

# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-4,
                    help='learning rate')
parser.add_argument('--decay', type=str, default='200',
                    help='learning rate decay type')
parser.add_argument('--gamma', type=float, default=0.5,
                    help='learning rate decay factor for step decay')
parser.add_argument('--optimizer', default='ADAM',
                    choices=('SGD', 'ADAM', 'RMSprop'),
                    help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum')
parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
                    help='ADAM beta')
parser.add_argument('--epsilon', type=float, default=1e-8,
                    help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight decay')
parser.add_argument('--gclip', type=float, default=0,
                    help='gradient clipping threshold (0 = no clipping)')

# Loss specifications
parser.add_argument('--loss', type=str, default='1*L1',
                    help='loss function configuration')
parser.add_argument('--skip_threshold', type=float, default=1e8,
                    help='skipping batch that has large error')

# Log specifications
parser.add_argument('--save', type=str, default='test',
                    help='file name to save')
parser.add_argument('--load', type=str, default='',
                    help='file name to load')
parser.add_argument('--resume', type=int, default=0,
                    help='resume from specific checkpoint')
parser.add_argument('--save_models', action='store_true',
                    help='save all intermediate models')
parser.add_argument('--print_every', type=int, default=100,
                    help='how many batches to wait before logging training status')
parser.add_argument('--save_results', action='store_true',
                    help='save output results')
parser.add_argument('--save_gt', action='store_true',
                    help='save low-resolution and high-resolution images together')

args = parser.parse_args()
template.set_template(args)

args.scale = list(map(lambda x: int(x), args.scale.split('+')))
args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')

if args.epochs == 0:
    args.epochs = 1e8

for arg in vars(args):
    if vars(args)[arg] == 'True':
        vars(args)[arg] = True
    elif vars(args)[arg] == 'False':
        vars(args)[arg] = False
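# Minimal usage sketch (illustrative, not used at runtime): '+'-separated CLI
# values are split into lists after parsing, so one run can carry several
# scales or datasets. Example values are illustrative only:
#   --scale 2+3+4          -> args.scale     == [2, 3, 4]
#   --data_test Set5+B100  -> args.data_test == ['Set5', 'B100']
def _split_args_demo():
    scale = '2+3+4'
    data_test = 'Set5+B100'
    assert list(map(int, scale.split('+'))) == [2, 3, 4]
    assert data_test.split('+') == ['Set5', 'B100']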

================================================
FILE: src/template.py
================================================
def set_template(args):
    # Set the templates here
    if args.template.find('jpeg') >= 0:
        args.data_train = 'DIV2K_jpeg'
        args.data_test = 'DIV2K_jpeg'
        args.epochs = 200
        args.decay = '100'

    if args.template.find('EDSR_paper') >= 0:
        args.model = 'EDSR'
        args.n_resblocks = 32
        args.n_feats = 256
        args.res_scale = 0.1

    if args.template.find('MDSR') >= 0:
        args.model = 'MDSR'
        args.patch_size = 48
        args.epochs = 650

    if args.template.find('DDBPN') >= 0:
        args.model = 'DDBPN'
        args.patch_size = 128
        args.scale = '4'

        args.data_test = 'Set5'

        args.batch_size = 20
        args.epochs = 1000
        args.decay = '500'
        args.gamma = 0.1
        args.weight_decay = 1e-4

        args.loss = '1*MSE'

    if args.template.find('GAN') >= 0:
        args.epochs = 200
        args.lr = 5e-5
        args.decay = '150'

    if args.template.find('RCAN') >= 0:
        args.model = 'RCAN'
        args.n_resgroups = 10
        args.n_resblocks = 20
        args.n_feats = 64
        args.chop = True

    if args.template.find('VDSR') >= 0:
        args.model = 'VDSR'
        args.n_resblocks = 20
        args.n_feats = 64
        args.patch_size = 41
        args.lr = 1e-1


================================================
FILE: src/trainer.py
================================================
import os
import math
from decimal import Decimal

import utility

import torch
import torch.nn.utils as utils
from tqdm import tqdm

class Trainer():
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.loss = my_loss
        self.optimizer = utility.make_optimizer(args, self.model)

        if self.args.load != '':
            self.optimizer.load(ckp.dir, epoch=len(ckp.log))

        self.error_last = 1e8

    def train(self):
        self.loss.step()
        epoch = self.optimizer.get_last_epoch() + 1
        lr = self.optimizer.get_lr()

        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
        )
        self.loss.start_log()
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()
        # TEMP
        self.loader_train.dataset.set_scale(0)
        for batch, (lr, hr, _,) in enumerate(self.loader_train):
            lr, hr = self.prepare(lr, hr)
            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()
            sr = self.model(lr, 0)
            loss = self.loss(sr, hr)
            loss.backward()
            if self.args.gclip > 0:
                utils.clip_grad_value_(
                    self.model.parameters(),
                    self.args.gclip
                )
            self.optimizer.step()

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch),
                    timer_model.release(),
                    timer_data.release()))

            timer_data.tic()

        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        self.optimizer.schedule()

    def test(self):
        torch.set_grad_enabled(False)

        epoch = self.optimizer.get_last_epoch()
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(
            torch.zeros(1, len(self.loader_test), len(self.scale))
        )
        self.model.eval()

        timer_test = utility.timer()
        if self.args.save_results:
            self.ckp.begin_background()
        for idx_data, d in enumerate(self.loader_test):
            for idx_scale, scale in enumerate(self.scale):
                d.dataset.set_scale(idx_scale)
                for lr, hr, filename in tqdm(d, ncols=80):
                    lr, hr = self.prepare(lr, hr)
                    sr = self.model(lr, idx_scale)
                    sr = utility.quantize(sr, self.args.rgb_range)

                    save_list = [sr]
                    self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
                        sr, hr, scale, self.args.rgb_range, dataset=d
                    )
                    if self.args.save_gt:
                        save_list.extend([lr, hr])

                    if self.args.save_results:
                        self.ckp.save_results(d, filename[0], save_list, scale)

                self.ckp.log[-1, idx_data, idx_scale] /= len(d)
                best = self.ckp.log.max(0)
                self.ckp.write_log(
                    '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                        d.dataset.name,
                        scale,
                        self.ckp.log[-1, idx_data, idx_scale],
                        best[0][idx_data, idx_scale],
                        best[1][idx_data, idx_scale] + 1
                    )
                )

        self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
        self.ckp.write_log('Saving...')

        if self.args.save_results:
            self.ckp.end_background()

        if not self.args.test_only:
            self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))

        self.ckp.write_log(
            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )

        torch.set_grad_enabled(True)

    def prepare(self, *args):
        device = torch.device('cpu' if self.args.cpu else 'cuda')
        def _prepare(tensor):
            if self.args.precision == 'half':
                tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]

    def terminate(self):
        if self.args.test_only:
            self.test()
            return True
        else:
            epoch = self.optimizer.get_last_epoch() + 1
            return epoch >= self.args.epochs


================================================
FILE: src/utility.py
================================================
import os
import math
import time
import datetime
from multiprocessing import Process
from multiprocessing import Queue

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import numpy as np
import imageio

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs

class timer():
    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        self.t0 = time.time()

    def toc(self, restart=False):
        diff = time.time() - self.t0
        if restart:
            self.t0 = time.time()
        return diff

    def hold(self):
        self.acc += self.toc()

    def release(self):
        ret = self.acc
        self.acc = 0

        return ret

    def reset(self):
        self.acc = 0

class checkpoint():
    def __init__(self, args):
        self.args = args
        self.ok = True
        self.log = torch.Tensor()
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        if not args.load:
            if not args.save:
                args.save = now
            self.dir = os.path.join('..', 'experiment', args.save)
        else:
            self.dir = os.path.join('..', 'experiment', args.load)
            if os.path.exists(self.dir):
                self.log = torch.load(self.get_path('psnr_log.pt'))
                print('Continue from epoch {}...'.format(len(self.log)))
            else:
                args.load = ''

        if args.reset:
            os.system('rm -rf ' + self.dir)
            args.load = ''

        os.makedirs(self.dir, exist_ok=True)
        os.makedirs(self.get_path('model'), exist_ok=True)
        for d in args.data_test:
            os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)

        open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
        self.log_file = open(self.get_path('log.txt'), open_type)
        with open(self.get_path('config.txt'), open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')

        self.n_processes = 8

    def get_path(self, *subdir):
        return os.path.join(self.dir, *subdir)

    def save(self, trainer, epoch, is_best=False):
        trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        trainer.optimizer.save(self.dir)
        torch.save(self.log, self.get_path('psnr_log.pt'))

    def add_log(self, log):
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        print(log)
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.get_path('log.txt'), 'a')

    def done(self):
        self.log_file.close()

    def plot_psnr(self, epoch):
        axis = np.linspace(1, epoch, epoch)
        for idx_data, d in enumerate(self.args.data_test):
            label = 'SR on {}'.format(d)
            fig = plt.figure()
            plt.title(label)
            for idx_scale, scale in enumerate(self.args.scale):
                plt.plot(
                    axis,
                    self.log[:, idx_data, idx_scale].numpy(),
                    label='Scale {}'.format(scale)
                )
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('PSNR')
            plt.grid(True)
            plt.savefig(self.get_path('test_{}.pdf'.format(d)))
            plt.close(fig)

    def begin_background(self):
        self.queue = Queue()

        def bg_target(queue):
            while True:
                if not queue.empty():
                    filename, tensor = queue.get()
                    if filename is None:
                        break
                    imageio.imwrite(filename, tensor.numpy())

        self.process = [
            Process(target=bg_target, args=(self.queue,)) \
            for _ in range(self.n_processes)
        ]

        for p in self.process:
            p.start()

    def end_background(self):
        for _ in range(self.n_processes):
            self.queue.put((None, None))
        while not self.queue.empty():
            time.sleep(1)
        for p in self.process:
            p.join()

    def save_results(self, dataset, filename, save_list, scale):
        if self.args.save_results:
            filename = self.get_path(
                'results-{}'.format(dataset.dataset.name),
                '{}_x{}_'.format(filename, scale)
            )

            postfix = ('SR', 'LR', 'HR')
            for v, p in zip(save_list, postfix):
                normalized = v[0].mul(255 / self.args.rgb_range)
                tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
                self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))

def quantize(img, rgb_range):
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)

def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
    if hr.nelement() == 1: return 0

    diff = (sr - hr) / rgb_range
    if dataset and dataset.dataset.benchmark:
        shave = scale
        if diff.size(1) > 1:
            gray_coeffs = [65.738, 129.057, 25.064]
            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
            diff = diff.mul(convert).sum(dim=1)
    else:
        shave = scale + 6

    valid = diff[..., shave:-shave, shave:-shave]
    mse = valid.pow(2).mean()

    return -10 * math.log10(mse)
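# Minimal usage sketch (illustrative, not used by the training code):
# PSNR = -10 * log10(MSE) on range-normalized images, after shaving a
# `scale`-pixel (benchmark) or `scale + 6`-pixel border. A uniform error of
# 0.1 on a [0, 1]-range image gives MSE = 0.01, i.e. 20 dB. Values are
# illustrative only.
def _calc_psnr_demo():
    hr = torch.zeros(1, 1, 32, 32)
    sr = hr + 0.1
    psnr = calc_psnr(sr, hr, scale=2, rgb_range=1)  # shave = 2 + 6 = 8
    assert abs(psnr - 20.0) < 1e-4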
def make_optimizer(args, target):
    '''
        make optimizer and scheduler together
    '''
    # optimizer
    trainable = filter(lambda x: x.requires_grad, target.parameters())
    kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}

    if args.optimizer == 'SGD':
        optimizer_class = optim.SGD
        kwargs_optimizer['momentum'] = args.momentum
    elif args.optimizer == 'ADAM':
        optimizer_class = optim.Adam
        kwargs_optimizer['betas'] = args.betas
        kwargs_optimizer['eps'] = args.epsilon
    elif args.optimizer == 'RMSprop':
        optimizer_class = optim.RMSprop
        kwargs_optimizer['eps'] = args.epsilon

    # scheduler
    milestones = list(map(lambda x: int(x), args.decay.split('-')))
    kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
    scheduler_class = lrs.MultiStepLR

    class CustomOptimizer(optimizer_class):
        def __init__(self, *args, **kwargs):
            super(CustomOptimizer, self).__init__(*args, **kwargs)

        def _register_scheduler(self, scheduler_class, **kwargs):
            self.scheduler = scheduler_class(self, **kwargs)

        def save(self, save_dir):
            torch.save(self.state_dict(), self.get_dir(save_dir))

        def load(self, load_dir, epoch=1):
            self.load_state_dict(torch.load(self.get_dir(load_dir)))
            if epoch > 1:
                for _ in range(epoch):
                    self.scheduler.step()

        def get_dir(self, dir_path):
            return os.path.join(dir_path, 'optimizer.pt')

        def schedule(self):
            self.scheduler.step()

        def get_lr(self):
            return self.scheduler.get_lr()[0]

        def get_last_epoch(self):
            return self.scheduler.last_epoch

    optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
    optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
    return optimizer


================================================
FILE: src/utils/__init__.py
================================================


================================================
FILE: src/utils/tools.py
================================================
import os
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F

def normalize(x):
    return x.mul_(2).add_(-1)

def same_padding(images, ksizes, strides, rates):
    assert len(images.size()) == 4
    batch_size, channel, rows, cols = images.size()
    out_rows = (rows + strides[0] - 1) // strides[0]
    out_cols = (cols + strides[1] - 1) // strides[1]
    effective_k_row = (ksizes[0] - 1) * rates[0] + 1
    effective_k_col = (ksizes[1] - 1) * rates[1] + 1
    padding_rows = max(0, (out_rows - 1) * strides[0] + effective_k_row - rows)
    padding_cols = max(0, (out_cols - 1) * strides[1] + effective_k_col - cols)
    # Pad the input
    padding_top = int(padding_rows / 2.)
    padding_left = int(padding_cols / 2.)
    padding_bottom = padding_rows - padding_top
    padding_right = padding_cols - padding_left
    paddings = (padding_left, padding_right, padding_top, padding_bottom)
    images = torch.nn.ZeroPad2d(paddings)(images)
    return images

def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract patches from images and put them in the C output dimension.
    :param padding:
    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
     each dimension of images
    :param strides: [stride_rows, stride_cols]
    :param rates: [dilation_rows, dilation_cols]
    :return: A Tensor
    """
    assert len(images.size()) == 4
    assert padding in ['same', 'valid']
    batch_size, channel, height, width = images.size()

    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding == 'valid':
        pass
    else:
        raise NotImplementedError('Unsupported padding type: {}. '
                                  'Only "same" or "valid" are supported.'.format(padding))

    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    patches = unfold(images)
    return patches  # [N, C*k*k, L], L is the total number of such blocks

def reduce_mean(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.mean(x, dim=i, keepdim=keepdim)
    return x

def reduce_std(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.std(x, dim=i, keepdim=keepdim)
    return x

def reduce_sum(x, axis=None, keepdim=False):
    if not axis:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x

================================================
FILE: src/videotester.py
================================================
import os
import math

import utility
from data import common

import torch
import cv2
from tqdm import tqdm

class VideoTester():
    def __init__(self, args, my_model, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.model = my_model

        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))

    def test(self):
        torch.set_grad_enabled(False)

        self.ckp.write_log('\nEvaluation on video:')
        self.model.eval()

        timer_test = utility.timer()
        for idx_scale, scale in enumerate(self.scale):
            vidcap = cv2.VideoCapture(self.args.dir_demo)
            total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
            vidwri = cv2.VideoWriter(
                self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)),
                cv2.VideoWriter_fourcc(*'XVID'),
                vidcap.get(cv2.CAP_PROP_FPS),
                (
                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                )
            )

            tqdm_test = tqdm(range(total_frames), ncols=80)
            for _ in tqdm_test:
                success, lr = vidcap.read()
                if not success:
                    break

                lr, = common.set_channel(lr, n_channels=self.args.n_colors)
                lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
                lr, = self.prepare(lr.unsqueeze(0))
                sr = self.model(lr, idx_scale)
                sr = utility.quantize(sr, self.args.rgb_range).squeeze(0)

                normalized = sr * 255 / self.args.rgb_range
                ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
                vidwri.write(ndarr)

            vidcap.release()
            vidwri.release()

        self.ckp.write_log(
            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )

        torch.set_grad_enabled(True)

    def prepare(self, *args):
        device = torch.device('cpu' if self.args.cpu else 'cuda')
        def _prepare(tensor):
            if self.args.precision == 'half':
                tensor = tensor.half()
            return tensor.to(device)

        return [_prepare(a) for a in args]