Repository: HarukiYqM/Non-Local-Sparse-Attention
Branch: main
Commit: fb3d42255832
Files: 43
Total size: 117.5 KB
Directory structure:
gitextract_22vbrkxt/
├── README.md
└── src/
├── __init__.py
├── data/
│ ├── __init__.py
│ ├── benchmark.py
│ ├── common.py
│ ├── demo.py
│ ├── div2k.py
│ ├── div2kjpeg.py
│ ├── sr291.py
│ ├── srdata.py
│ └── video.py
├── dataloader.py
├── demo.sh
├── loss/
│ ├── __init__.py
│ ├── __loss__.py
│ ├── adversarial.py
│ ├── demo.sh
│ ├── discriminator.py
│ ├── hash.py
│ └── vgg.py
├── main.py
├── model/
│ ├── LICENSE
│ ├── README.md
│ ├── __init__.py
│ ├── attention.py
│ ├── common.py
│ ├── ddbpn.py
│ ├── edsr.py
│ ├── mdsr.py
│ ├── mssr.py
│ ├── nlsn.py
│ ├── rcan.py
│ ├── rdn.py
│ ├── utils/
│ │ ├── __init__.py
│ │ └── tools.py
│ └── vdsr.py
├── option.py
├── template.py
├── trainer.py
├── utility.py
├── utils/
│ ├── __init__.py
│ └── tools.py
└── videotester.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# Image Super-Resolution with Non-Local Sparse Attention
This repository is for NLSN introduced in the following paper "Image Super-Resolution with Non-Local Sparse Attention", CVPR2021, [[Link]](https://openaccess.thecvf.com/content/CVPR2021/papers/Mei_Image_Super-Resolution_With_Non-Local_Sparse_Attention_CVPR_2021_paper.pdf)
The code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) and tested on Ubuntu 18.04 (Python 3.6, PyTorch >= 1.1.0) with V100 GPUs.
## Contents
1. [Introduction](#introduction)
2. [Train](#train)
3. [Test](#test)
4. [Citation](#citation)
5. [Acknowledgements](#acknowledgements)
## Introduction
Both Non-Local (NL) operation and sparse representation are crucial for Single Image Super-Resolution (SISR). In this paper, we investigate their combinations and propose a novel Non-Local Sparse Attention (NLSA) with a dynamic sparse attention pattern. NLSA is designed to retain the long-range modeling capability of the NL operation while enjoying the robustness and high efficiency of sparse representation. Specifically, NLSA rectifies non-local attention with spherical locality sensitive hashing (LSH) that partitions the input space into hash buckets of related features. For every query signal, NLSA assigns a bucket to it and only computes attention within the bucket. The resulting sparse attention prevents the model from attending to locations that are noisy and less informative, while reducing the computational cost from quadratic to asymptotically linear with respect to the spatial size. Extensive experiments validate the effectiveness and efficiency of NLSA. With a few non-local sparse attention modules, our architecture, called the non-local sparse network (NLSN), reaches state-of-the-art performance for SISR quantitatively and qualitatively.

Non-Local Sparse Attention.

Non-Local Sparse Network.
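The bucketing idea above can be illustrated in a few lines of plain Python. This is only a toy sketch of spherical LSH bucketing, not the repository's implementation (which lives in `src/model/attention.py`, operates on PyTorch tensors, and is presumably controlled by the `--n_hashes` and `--chunk_size` options seen in the commands below); the function names here are invented for illustration.

```python
import math
import random

random.seed(0)

def spherical_lsh_bucket(vec, rotations):
    """Spherical LSH: the bucket id is the argmax over projections onto
    random rotation vectors and their negations."""
    scores = []
    for r in rotations:
        dot = sum(a * b for a, b in zip(vec, r))
        scores.append(dot)   # projection onto r
        scores.append(-dot)  # projection onto -r
    return max(range(len(scores)), key=lambda i: scores[i])

def nlsa_sketch(queries, n_rot=4):
    """Partition query vectors into hash buckets; full attention would
    then be computed only among members of the same bucket, so the cost
    scales with bucket size rather than the full spatial extent."""
    dim = len(queries[0])
    rotations = [
        [random.gauss(0, 1) for _ in range(dim)] for _ in range(n_rot)
    ]
    buckets = {}
    for i, q in enumerate(queries):
        norm = math.sqrt(sum(x * x for x in q)) or 1.0
        unit = [x / norm for x in q]  # map the feature onto the unit sphere
        buckets.setdefault(spherical_lsh_bucket(unit, rotations), []).append(i)
    return buckets

# Similar vectors tend to land in the same bucket; dissimilar ones are
# separated, which is what makes the resulting attention sparse.
feats = [[1.0, 0.1], [0.9, 0.2], [-1.0, 0.0], [0.1, 1.0]]
buckets = nlsa_sketch(feats)
```

Every query falls into exactly one bucket per hash round; the paper's multiple hash rounds reduce the chance that two similar features are unluckily split apart.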
## Train
### Prepare training data
1. Download DIV2K training data (800 training + 100 validation images) from [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar).
2. Specify '--dir_data' based on the HR and LR images path.
For more information, please refer to [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch).
### Begin to train
1. (optional) Download pretrained models for our paper.
    Pre-trained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zz2a1ih3euzuH3HvWDN-uSki3USym9Cq?usp=sharing)
2. Cd to 'src', run the following script to train models.
**Example command is in the file 'demo.sh'.**
```bash
# Example X2 SR
python main.py --dir_data ../../ --n_GPUs 4 --rgb_range 1 --chunk_size 144 --n_hashes 4 --save_models --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model NLSN --scale 2 --patch_size 96 --save NLSN_x2 --data_train DIV2K
```
## Test
### Quick start
1. Download benchmark datasets from [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/benchmark.tar)
2. (optional) Download pretrained models for our paper.
    All the models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zz2a1ih3euzuH3HvWDN-uSki3USym9Cq?usp=sharing)
3. Cd to 'src', run the following scripts.
**Example command is in the file 'demo.sh'.**
```bash
# No self-ensemble: NLSN
# Example X2 SR
python main.py --dir_data ../../ --model NLSN --chunk_size 144 --data_test Set5+Set14+B100+Urban100 --n_hashes 4 --chop --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1 --pre_train model_x2.pt --test_only
```
## Citation
If you find the code helpful in your research or work, please cite the following papers.
```
@InProceedings{Mei_2021_CVPR,
    author    = {Mei, Yiqun and Fan, Yuchen and Zhou, Yuqian},
    title     = {Image Super-Resolution With Non-Local Sparse Attention},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2021},
    pages     = {3517-3526}
}
@InProceedings{Lim_2017_CVPR_Workshops,
    author    = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
    title     = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
    booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
    month     = {July},
    year      = {2017}
}
```
## Acknowledgements
This code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) and [reformer-pytorch](https://github.com/lucidrains/reformer-pytorch). We thank the authors for sharing their codes.
================================================
FILE: src/__init__.py
================================================
================================================
FILE: src/data/__init__.py
================================================
from importlib import import_module
#from dataloader import MSDataLoader
from torch.utils.data import dataloader
from torch.utils.data import ConcatDataset
# This is a simple wrapper function for ConcatDataset
class MyConcatDataset(ConcatDataset):
    def __init__(self, datasets):
        super(MyConcatDataset, self).__init__(datasets)
        self.train = datasets[0].train

    def set_scale(self, idx_scale):
        for d in self.datasets:
            if hasattr(d, 'set_scale'): d.set_scale(idx_scale)

class Data:
    def __init__(self, args):
        self.loader_train = None
        if not args.test_only:
            datasets = []
            for d in args.data_train:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data.' + module_name.lower())
                datasets.append(getattr(m, module_name)(args, name=d))

            self.loader_train = dataloader.DataLoader(
                MyConcatDataset(datasets),
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=not args.cpu,
                num_workers=args.n_threads,
            )

        self.loader_test = []
        for d in args.data_test:
            if d in ['Set5', 'Set14', 'B100', 'Urban100']:
                m = import_module('data.benchmark')
                testset = getattr(m, 'Benchmark')(args, train=False, name=d)
            else:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data.' + module_name.lower())
                testset = getattr(m, module_name)(args, train=False, name=d)

            self.loader_test.append(
                dataloader.DataLoader(
                    testset,
                    batch_size=1,
                    shuffle=False,
                    pin_memory=not args.cpu,
                    num_workers=args.n_threads,
                )
            )
================================================
FILE: src/data/benchmark.py
================================================
import os
from data import common
from data import srdata
import numpy as np
import torch
import torch.utils.data as data
class Benchmark(srdata.SRData):
    def __init__(self, args, name='', train=True, benchmark=True):
        super(Benchmark, self).__init__(
            args, name=name, train=train, benchmark=True
        )

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, 'benchmark', self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        if self.input_large:
            self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
        else:
            self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('', '.png')
================================================
FILE: src/data/common.py
================================================
import random
import numpy as np
import skimage.color as sc
import torch
def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
    ih, iw = args[0].shape[:2]

    if not input_large:
        p = scale if multi else 1
        tp = p * patch_size
        ip = tp // scale
    else:
        tp = patch_size
        ip = patch_size

    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)

    if not input_large:
        tx, ty = scale * ix, scale * iy
    else:
        tx, ty = ix, iy

    ret = [
        args[0][iy:iy + ip, ix:ix + ip, :],
        *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
    ]

    return ret

def set_channel(*args, n_channels=3):
    def _set_channel(img):
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)

        c = img.shape[2]
        if n_channels == 1 and c == 3:
            img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
        elif n_channels == 3 and c == 1:
            img = np.concatenate([img] * n_channels, 2)

        return img

    return [_set_channel(a) for a in args]

def np2Tensor(*args, rgb_range=255):
    def _np2Tensor(img):
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        tensor = torch.from_numpy(np_transpose).float()
        tensor.mul_(rgb_range / 255)

        return tensor

    return [_np2Tensor(a) for a in args]

def augment(*args, hflip=True, rot=True):
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip: img = img[:, ::-1, :]
        if vflip: img = img[::-1, :, :]
        if rot90: img = img.transpose(1, 0, 2)

        return img

    return [_augment(a) for a in args]
================================================
FILE: src/data/demo.py
================================================
import os
from data import common
import numpy as np
import imageio
import torch
import torch.utils.data as data
class Demo(data.Dataset):
    def __init__(self, args, name='Demo', train=False, benchmark=False):
        self.args = args
        self.name = name
        self.scale = args.scale
        self.idx_scale = 0
        self.train = False
        self.benchmark = benchmark

        self.filelist = []
        for f in os.listdir(args.dir_demo):
            if f.find('.png') >= 0 or f.find('.jp') >= 0:
                self.filelist.append(os.path.join(args.dir_demo, f))
        self.filelist.sort()

    def __getitem__(self, idx):
        filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
        lr = imageio.imread(self.filelist[idx])
        lr, = common.set_channel(lr, n_channels=self.args.n_colors)
        lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)

        return lr_t, -1, filename

    def __len__(self):
        return len(self.filelist)

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale
================================================
FILE: src/data/div2k.py
================================================
import os
from data import srdata
class DIV2K(srdata.SRData):
    def __init__(self, args, name='DIV2K', train=True, benchmark=False):
        data_range = [r.split('-') for r in args.data_range.split('/')]
        if train:
            data_range = data_range[0]
        else:
            if args.test_only and len(data_range) == 1:
                data_range = data_range[0]
            else:
                data_range = data_range[1]

        self.begin, self.end = list(map(lambda x: int(x), data_range))
        super(DIV2K, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )

    def _scan(self):
        names_hr, names_lr = super(DIV2K, self)._scan()
        names_hr = names_hr[self.begin - 1:self.end]
        names_lr = [n[self.begin - 1:self.end] for n in names_lr]

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        super(DIV2K, self)._set_filesystem(dir_data)
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
        if self.input_large: self.dir_lr += 'L'
================================================
FILE: src/data/div2kjpeg.py
================================================
import os
from data import srdata
from data import div2k
class DIV2KJPEG(div2k.DIV2K):
    def __init__(self, args, name='', train=True, benchmark=False):
        self.q_factor = int(name.replace('DIV2K-Q', ''))
        super(DIV2KJPEG, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, 'DIV2K')
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(
            self.apath, 'DIV2K_Q{}'.format(self.q_factor)
        )
        if self.input_large: self.dir_lr += 'L'
        self.ext = ('.png', '.jpg')
================================================
FILE: src/data/sr291.py
================================================
from data import srdata
class SR291(srdata.SRData):
    def __init__(self, args, name='SR291', train=True, benchmark=False):
        super(SR291, self).__init__(args, name=name)
================================================
FILE: src/data/srdata.py
================================================
import os
import glob
import random
import pickle
from data import common
import numpy as np
import imageio
import torch
import torch.utils.data as data
class SRData(data.Dataset):
    def __init__(self, args, name='', train=True, benchmark=False):
        self.args = args
        self.name = name
        self.train = train
        self.split = 'train' if train else 'test'
        self.do_eval = True
        self.benchmark = benchmark
        self.input_large = (args.model == 'VDSR')
        self.scale = args.scale
        self.idx_scale = 0

        self._set_filesystem(args.dir_data)
        if args.ext.find('img') < 0:
            path_bin = os.path.join(self.apath, 'bin')
            os.makedirs(path_bin, exist_ok=True)

        list_hr, list_lr = self._scan()
        if args.ext.find('img') >= 0 or benchmark:
            self.images_hr, self.images_lr = list_hr, list_lr
        elif args.ext.find('sep') >= 0:
            os.makedirs(
                self.dir_hr.replace(self.apath, path_bin),
                exist_ok=True
            )
            for s in self.scale:
                os.makedirs(
                    os.path.join(
                        self.dir_lr.replace(self.apath, path_bin),
                        'X{}'.format(s)
                    ),
                    exist_ok=True
                )

            self.images_hr, self.images_lr = [], [[] for _ in self.scale]
            for h in list_hr:
                b = h.replace(self.apath, path_bin)
                b = b.replace(self.ext[0], '.pt')
                self.images_hr.append(b)
                self._check_and_load(args.ext, h, b, verbose=True)
            for i, ll in enumerate(list_lr):
                for l in ll:
                    b = l.replace(self.apath, path_bin)
                    b = b.replace(self.ext[1], '.pt')
                    self.images_lr[i].append(b)
                    self._check_and_load(args.ext, l, b, verbose=True)

        if train:
            n_patches = args.batch_size * args.test_every
            n_images = len(args.data_train) * len(self.images_hr)
            if n_images == 0:
                self.repeat = 0
            else:
                self.repeat = max(n_patches // n_images, 1)

    # The functions below are used to prepare images
    def _scan(self):
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
        )
        names_lr = [[] for _ in self.scale]
        for f in names_hr:
            filename, _ = os.path.splitext(os.path.basename(f))
            for si, s in enumerate(self.scale):
                names_lr[si].append(os.path.join(
                    self.dir_lr, 'X{}/{}x{}{}'.format(
                        s, filename, s, self.ext[1]
                    )
                ))

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        if self.input_large: self.dir_lr += 'L'
        self.ext = ('.png', '.png')

    def _check_and_load(self, ext, img, f, verbose=True):
        if not os.path.isfile(f) or ext.find('reset') >= 0:
            if verbose:
                print('Making a binary: {}'.format(f))
            with open(f, 'wb') as _f:
                pickle.dump(imageio.imread(img), _f)

    def __getitem__(self, idx):
        lr, hr, filename = self._load_file(idx)
        pair = self.get_patch(lr, hr)
        pair = common.set_channel(*pair, n_channels=self.args.n_colors)
        pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)

        return pair_t[0], pair_t[1], filename

    def __len__(self):
        if self.train:
            return len(self.images_hr) * self.repeat
        else:
            return len(self.images_hr)

    def _get_index(self, idx):
        if self.train:
            return idx % len(self.images_hr)
        else:
            return idx

    def _load_file(self, idx):
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]
        f_lr = self.images_lr[self.idx_scale][idx]

        filename, _ = os.path.splitext(os.path.basename(f_hr))
        if self.args.ext == 'img' or self.benchmark:
            hr = imageio.imread(f_hr)
            lr = imageio.imread(f_lr)
        elif self.args.ext.find('sep') >= 0:
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)
            with open(f_lr, 'rb') as _f:
                lr = pickle.load(_f)

        return lr, hr, filename

    def get_patch(self, lr, hr):
        scale = self.scale[self.idx_scale]
        if self.train:
            lr, hr = common.get_patch(
                lr, hr,
                patch_size=self.args.patch_size,
                scale=scale,
                multi=(len(self.scale) > 1),
                input_large=self.input_large
            )
            if not self.args.no_augment: lr, hr = common.augment(lr, hr)
        else:
            ih, iw = lr.shape[:2]
            hr = hr[0:ih * scale, 0:iw * scale]

        return lr, hr

    def set_scale(self, idx_scale):
        if not self.input_large:
            self.idx_scale = idx_scale
        else:
            self.idx_scale = random.randint(0, len(self.scale) - 1)
================================================
FILE: src/data/video.py
================================================
import os
from data import common
import cv2
import numpy as np
import imageio
import torch
import torch.utils.data as data
class Video(data.Dataset):
    def __init__(self, args, name='Video', train=False, benchmark=False):
        self.args = args
        self.name = name
        self.scale = args.scale
        self.idx_scale = 0
        self.train = False
        self.do_eval = False
        self.benchmark = benchmark

        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
        self.vidcap = cv2.VideoCapture(args.dir_demo)
        self.n_frames = 0
        self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __getitem__(self, idx):
        success, lr = self.vidcap.read()
        if success:
            self.n_frames += 1
            lr, = common.set_channel(lr, n_channels=self.args.n_colors)
            lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)

            return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
        else:
            # fixed: was a bare `vidcap.release()`, an undefined name
            self.vidcap.release()
            return None

    def __len__(self):
        return self.total_frames

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale
================================================
FILE: src/dataloader.py
================================================
import sys  # needed for sys.exc_info() in the worker loop below
import threading
import random

import torch
import torch.multiprocessing as multiprocessing
from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler
from torch.utils.data import RandomSampler
from torch.utils.data import BatchSampler
from torch.utils.data import _utils
from torch.utils.data.dataloader import _DataLoaderIter

from torch.utils.data._utils import collate
from torch.utils.data._utils import signal_handling
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data._utils import ExceptionWrapper
from torch.utils.data._utils import IS_WINDOWS
from torch.utils.data._utils.worker import ManagerWatchdog

from torch._six import queue

def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
    try:
        collate._use_shared_memory = True
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue

            if r is None:
                assert done_event.is_set()
                return
            elif done_event.is_set():
                continue

            idx, batch_indices = r
            try:
                idx_scale = 0
                if len(scale) > 1 and dataset.train:
                    idx_scale = random.randrange(0, len(scale))
                    dataset.set_scale(idx_scale)

                samples = collate_fn([dataset[i] for i in batch_indices])
                samples.append(idx_scale)
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples

    except KeyboardInterrupt:
        pass

class _MSDataLoaderIter(_DataLoaderIter):
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        self.sample_iter = iter(self.batch_sampler)

        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()

            base_seed = torch.LongTensor(1).random_()[0]

            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        index_queue,
                        self.worker_result_queue,
                        self.done_event,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                w.daemon = True
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(
                        self.worker_result_queue,
                        self.data_queue,
                        torch.cuda.current_device(),
                        self.done_event
                    )
                )
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                self.pin_memory_thread = pin_memory_thread
            else:
                self.data_queue = self.worker_result_queue

            _utils.signal_handling._set_worker_pids(
                id(self), tuple(w.pid for w in self.workers)
            )
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            for _ in range(2 * self.num_workers):
                self._put_indices()

class MSDataLoader(DataLoader):
    def __init__(self, cfg, *args, **kwargs):
        super(MSDataLoader, self).__init__(
            *args, **kwargs, num_workers=cfg.n_threads
        )
        self.scale = cfg.scale

    def __iter__(self):
        return _MSDataLoaderIter(self)
================================================
FILE: src/demo.sh
================================================
#!/bin/bash
#Train x2
python main.py --dir_data ../../ --n_GPUs 4 --rgb_range 1 --chunk_size 144 --n_hashes 4 --save_models --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model NLSN --scale 2 --patch_size 96 --save NLSN_x2 --data_train DIV2K
#Test x2
python main.py --dir_data ../../ --model NLSN --chunk_size 144 --data_test Set5+Set14+B100+Urban100 --n_hashes 4 --chop --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1 --pre_train model_x2.pt --test_only
================================================
FILE: src/loss/__init__.py
================================================
import os
from importlib import import_module
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class Loss(nn.modules.loss._Loss):
    def __init__(self, args, ckp):
        super(Loss, self).__init__()
        print('Preparing loss function:')

        self.n_GPUs = args.n_GPUs
        self.loss = []
        self.loss_module = nn.ModuleList()
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':
                loss_function = nn.MSELoss()
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif loss_type.find('VGG') >= 0:
                module = import_module('loss.vgg')
                loss_function = getattr(module, 'VGG')(
                    loss_type[3:],
                    rgb_range=args.rgb_range
                )
            elif loss_type.find('GAN') >= 0:
                module = import_module('loss.adversarial')
                loss_function = getattr(module, 'Adversarial')(
                    args,
                    loss_type
                )

            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function}
            )
            if loss_type.find('GAN') >= 0:
                self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})

        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for l in self.loss:
            if l['function'] is not None:
                print('{:.3f} * {}'.format(l['weight'], l['type']))
                self.loss_module.append(l['function'])

        self.log = torch.Tensor()

        device = torch.device('cpu' if args.cpu else 'cuda')
        self.loss_module.to(device)
        if args.precision == 'half': self.loss_module.half()
        if not args.cpu and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(
                self.loss_module, range(args.n_GPUs)
            )

        if args.load != '': self.load(ckp.dir, cpu=args.cpu)

    def forward(self, sr, hr):
        losses = []
        for i, l in enumerate(self.loss):
            if l['function'] is not None:
                loss = l['function'](sr, hr)
                effective_loss = l['weight'] * loss
                losses.append(effective_loss)
                self.log[-1, i] += effective_loss.item()
            elif l['type'] == 'DIS':
                self.log[-1, i] += self.loss[i - 1]['function'].loss

        loss_sum = sum(losses)
        if len(self.loss) > 1:
            self.log[-1, -1] += loss_sum.item()

        return loss_sum

    def step(self):
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                l.scheduler.step()

    def start_log(self):
        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))

    def end_log(self, n_batches):
        self.log[-1].div_(n_batches)

    def display_loss(self, batch):
        n_samples = batch + 1
        log = []
        for l, c in zip(self.loss, self.log[-1]):
            log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))

        return ''.join(log)

    def plot_loss(self, apath, epoch):
        axis = np.linspace(1, epoch, epoch)
        for i, l in enumerate(self.loss):
            label = '{} Loss'.format(l['type'])
            fig = plt.figure()
            plt.title(label)
            plt.plot(axis, self.log[:, i].numpy(), label=label)
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
            plt.close(fig)

    def get_loss_module(self):
        if self.n_GPUs == 1:
            return self.loss_module
        else:
            return self.loss_module.module

    def save(self, apath):
        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))

    def load(self, apath, cpu=False):
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        self.load_state_dict(torch.load(
            os.path.join(apath, 'loss.pt'),
            **kwargs
        ))
        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                for _ in range(len(self.log)): l.scheduler.step()
================================================
FILE: src/loss/__loss__.py
================================================
================================================
FILE: src/loss/adversarial.py
================================================
import utility
from types import SimpleNamespace
from model import common
from loss import discriminator
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Adversarial(nn.Module):
    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.dis = discriminator.Discriminator(args)
        if gan_type == 'WGAN_GP':
            # see https://arxiv.org/pdf/1704.00028.pdf pp.4
            optim_dict = {
                'optimizer': 'ADAM',
                'betas': (0, 0.9),
                'epsilon': 1e-8,
                'lr': 1e-5,
                'weight_decay': args.weight_decay,
                'decay': args.decay,
                'gamma': args.gamma
            }
            optim_args = SimpleNamespace(**optim_dict)
        else:
            optim_args = args

        self.optimizer = utility.make_optimizer(optim_args, self.dis)

    def forward(self, fake, real):
        # updating discriminator...
        self.loss = 0
        fake_detach = fake.detach()  # do not backpropagate through G
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            # d: B x 1 tensor
            d_fake = self.dis(fake_detach)
            d_real = self.dis(real)
            retain_graph = False
            if self.gan_type == 'GAN':
                loss_d = self.bce(d_real, d_fake)
            elif self.gan_type.find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.dis(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
            elif self.gan_type == 'RGAN':
                better_real = d_real - d_fake.mean(dim=0, keepdim=True)
                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
                loss_d = self.bce(better_real, better_fake)
                retain_graph = True

            # Discriminator update
            self.loss += loss_d.item()
            loss_d.backward(retain_graph=retain_graph)
            self.optimizer.step()

            if self.gan_type == 'WGAN':
                for p in self.dis.parameters():
                    p.data.clamp_(-1, 1)

        self.loss /= self.gan_k

        # updating generator...
        d_fake_bp = self.dis(fake)  # for backpropagation, use fake as it is
        if self.gan_type == 'GAN':
            label_real = torch.ones_like(d_fake_bp)
            loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_bp.mean()
        elif self.gan_type == 'RGAN':
            better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
            loss_g = self.bce(better_fake, better_real)

        # Generator loss
        return loss_g

    def state_dict(self, *args, **kwargs):
        state_discriminator = self.dis.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()

        return dict(**state_discriminator, **state_optimizer)

    def bce(self, real, fake):
        label_real = torch.ones_like(real)
        label_fake = torch.zeros_like(fake)
        bce_real = F.binary_cross_entropy_with_logits(real, label_real)
        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
        bce_loss = bce_real + bce_fake
        return bce_loss
# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
================================================
FILE: src/loss/demo.sh
================================================
================================================
FILE: src/loss/discriminator.py
================================================
from model import common
import torch.nn as nn
class Discriminator(nn.Module):
    '''
        output is not normalized
    '''
    def __init__(self, args):
        super(Discriminator, self).__init__()

        in_channels = args.n_colors
        out_channels = 64
        depth = 7

        def _block(_in_channels, _out_channels, stride=1):
            return nn.Sequential(
                nn.Conv2d(
                    _in_channels,
                    _out_channels,
                    3,
                    padding=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(_out_channels),
                nn.LeakyReLU(negative_slope=0.2, inplace=True)
            )

        m_features = [_block(in_channels, out_channels)]
        for i in range(depth):
            in_channels = out_channels
            if i % 2 == 1:
                stride = 1
                out_channels *= 2
            else:
                stride = 2
            m_features.append(_block(in_channels, out_channels, stride=stride))

        patch_size = args.patch_size // (2**((depth + 1) // 2))
        m_classifier = [
            nn.Linear(out_channels * patch_size**2, 1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(1024, 1)
        ]

        self.features = nn.Sequential(*m_features)
        self.classifier = nn.Sequential(*m_classifier)

    def forward(self, x):
        features = self.features(x)
        output = self.classifier(features.view(features.size(0), -1))

        return output
================================================
FILE: src/loss/hash.py
================================================
from model import common
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
class HASH(nn.Module):
    def __init__(self):
        super(HASH, self).__init__()
        self.l1 = nn.L1Loss()

    def forward(self, sr, qk, orders, hr, m=3):
        # hash loss
        qk = F.normalize(qk, p=2, dim=1, eps=5e-5)
        N, C, H, W = qk.shape
        qk = qk.view(N, C, H * W)
        qk_t = qk.permute(0, 2, 1).contiguous()
        similarity_map = F.relu(torch.matmul(qk_t, qk), inplace=True)  # [N, H*W, H*W]
        orders = orders.unsqueeze(2).expand_as(similarity_map)  # [N, H*W, H*W]
        orders_t = torch.transpose(orders, 1, 2)
        dist = torch.pow(orders - orders_t, 2)
        ls = torch.mean(similarity_map * torch.log(torch.exp(dist + m) + 1))
        ld = torch.mean((1 - similarity_map) * torch.log(torch.exp(-dist + m) + 1))
        loss = 0.005 * (ls + ld) + self.l1(sr, hr)
        return loss
================================================
FILE: src/loss/vgg.py
================================================
from model import common
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
class VGG(nn.Module):
def __init__(self, conv_index, rgb_range=1):
super(VGG, self).__init__()
vgg_features = models.vgg19(pretrained=True).features
modules = [m for m in vgg_features]
if conv_index.find('22') >= 0:
self.vgg = nn.Sequential(*modules[:8])
elif conv_index.find('54') >= 0:
self.vgg = nn.Sequential(*modules[:35])
vgg_mean = (0.485, 0.456, 0.406)
vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
for p in self.parameters():
p.requires_grad = False
def forward(self, sr, hr):
def _forward(x):
x = self.sub_mean(x)
x = self.vgg(x)
return x
vgg_sr = _forward(sr)
with torch.no_grad():
vgg_hr = _forward(hr.detach())
loss = F.mse_loss(vgg_sr, vgg_hr)
return loss
================================================
FILE: src/main.py
================================================
import torch
import utility
import data
import model
import loss
from option import args
from trainer import Trainer
torch.manual_seed(args.seed)
checkpoint = utility.checkpoint(args)
def main():
global model
if args.data_test == ['video']:
from videotester import VideoTester
model = model.Model(args, checkpoint)
print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))
t = VideoTester(args, model, checkpoint)
t.test()
else:
if checkpoint.ok:
loader = data.Data(args)
_model = model.Model(args, checkpoint)
print('Total params: %.2fM' % (sum(p.numel() for p in _model.parameters())/1000000.0))
_loss = loss.Loss(args, checkpoint) if not args.test_only else None
t = Trainer(args, loader, _model, _loss, checkpoint)
while not t.terminate():
t.train()
t.test()
checkpoint.done()
if __name__ == '__main__':
main()
================================================
FILE: src/model/LICENSE
================================================
MIT License
Copyright (c) 2018 Sanghyun Son
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: src/model/README.md
================================================
# EDSR-PyTorch

This repository is an official PyTorch implementation of the paper **"Enhanced Deep Residual Networks for Single Image Super-Resolution"** from **CVPRW 2017 (2nd NTIRE workshop)**.
You can find the original code and more information from [here](https://github.com/LimBee/NTIRE2017).
If you find our work useful in your research or publication, please cite our work:
[1] Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee, **"Enhanced Deep Residual Networks for Single Image Super-Resolution,"** 2nd NTIRE: New Trends in Image Restoration and Enhancement workshop and challenge on image super-resolution in conjunction with **CVPR 2017**. [[PDF](http://openaccess.thecvf.com/content_cvpr_2017_workshops/w12/papers/Lim_Enhanced_Deep_Residual_CVPR_2017_paper.pdf)] [[arXiv](https://arxiv.org/abs/1707.02921)] [[Slide](https://cv.snu.ac.kr/research/EDSR/Presentation_v3(release).pptx)]
```
@InProceedings{Lim_2017_CVPR_Workshops,
author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
month = {July},
year = {2017}
}
```
We provide scripts for reproducing all the results from our paper. You can train your own model from scratch, or use a pre-trained model to enlarge your images.
**Differences from the Torch version**
* Codes are much more compact. (Removed all unnecessary parts.)
* Models are smaller. (About half.)
* Slightly better performances.
* Training and evaluation requires less memory.
* Python-based.
## Dependencies
* Python 3.6
* PyTorch >= 0.4.0
* numpy
* skimage
* **imageio**
* matplotlib
* tqdm
**Recent updates**
* July 22, 2018
* Thanks to recent commits, RDN and RCAN are now included. Please see ``code/demo.sh`` to train/test those models.
* The dataloader is now much more stable than the previous version. Please delete the ``DIV2K/bin`` folder created before this commit, and avoid using the ``--ext bin`` argument. Our code will automatically pre-decode png images before training. If you do not have enough space (~10GB) on your disk, we recommend ``--ext img`` (but it is SLOW!).
## Code
Clone this repository into any place you want.
```bash
git clone https://github.com/thstkdgus35/EDSR-PyTorch
cd EDSR-PyTorch
```
## Quick start (Demo)
You can test our super-resolution algorithm with your own images. Place your images in the ``test`` folder (like ``test/``). We support **png** and **jpeg** files.
Run the script in the ``src`` folder. Before you run the demo, uncomment the appropriate line in ```demo.sh``` for the model you want to execute.
```bash
cd src # You are now in */EDSR-PyTorch/src
sh demo.sh
```
You can find the result images from ```experiment/test/results``` folder.
| Model | Scale | File name (.pt) | Parameters | **PSNR** |
| --- | --- | --- | --- | --- |
| **EDSR** | 2 | EDSR_baseline_x2 | 1.37 M | 34.61 dB |
| | | *EDSR_x2 | 40.7 M | 35.03 dB |
| | 3 | EDSR_baseline_x3 | 1.55 M | 30.92 dB |
| | | *EDSR_x3 | 43.7 M | 31.26 dB |
| | 4 | EDSR_baseline_x4 | 1.52 M | 28.95 dB |
| | | *EDSR_x4 | 43.1 M | 29.25 dB |
| **MDSR** | 2 | MDSR_baseline | 3.23 M | 34.63 dB |
| | | *MDSR | 7.95 M| 34.92 dB |
| | 3 | MDSR_baseline | | 30.94 dB |
| | | *MDSR | | 31.22 dB |
| | 4 | MDSR_baseline | | 28.97 dB |
| | | *MDSR | | 29.24 dB |
*Baseline models are in ``experiment/model``. Please download our final models from [here](https://cv.snu.ac.kr/research/EDSR/model_pytorch.tar) (542MB)
**We measured PSNR using DIV2K 0801 ~ 0900, RGB channels, without self-ensemble. (scale + 2) pixels from the image boundary are ignored.
You can evaluate your models with widely-used benchmark datasets:
[Set5 - Bevilacqua et al. BMVC 2012](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html),
[Set14 - Zeyde et al. LNCS 2010](https://sites.google.com/site/romanzeyde/research-interests),
[B100 - Martin et al. ICCV 2001](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/),
[Urban100 - Huang et al. CVPR 2015](https://sites.google.com/site/jbhuang0604/publications/struct_sr).
For these datasets, we first convert the result images to YCbCr color space and evaluate PSNR on the Y channel only. You can download the [benchmark datasets](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) (250MB). Set ``--dir_data`` to the location of the benchmark folders to evaluate EDSR and MDSR with them.
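The Y-channel PSNR evaluation described above can be sketched as follows (ITU-R BT.601 luma coefficients; the function names and the boundary shave are ours, not the exact evaluation script):

```python
import numpy as np

def rgb_to_y(img):
    """Luma (Y) channel of an RGB image in [0, 255] (ITU-R BT.601)."""
    r, g, b = img[..., 0], img[..., 1], img[..., 2]
    return 16.0 + (65.481 * r + 128.553 * g + 24.966 * b) / 255.0

def psnr_y(sr, hr, shave):
    """PSNR on the Y channel, ignoring `shave` pixels at the image boundary."""
    diff = (rgb_to_y(sr) - rgb_to_y(hr))[shave:-shave, shave:-shave]
    return 10 * np.log10(255.0 ** 2 / np.mean(diff ** 2))

rng = np.random.default_rng(0)
hr = rng.uniform(0, 255, size=(64, 64, 3))
sr = np.clip(hr + rng.normal(0, 1.0, size=hr.shape), 0, 255)  # slightly noisy "SR" result
score = psnr_y(sr, hr, shave=4)
```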
## How to train EDSR and MDSR
We used [DIV2K](http://www.vision.ee.ethz.ch/%7Etimofter/publications/Agustsson-CVPRW-2017.pdf) dataset to train our model. Please download it from [here](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (7.1GB).
Unpack the tar file to any place you want. Then, change the ```dir_data``` argument in ```src/option.py``` to the place where DIV2K images are located.
We recommend pre-processing the images before training. This step decodes all **png** files and saves them as binaries. Use the ``--ext sep_reset`` argument on your first run. Afterwards, you can skip the decoding step and use the saved binaries with the ``--ext sep`` argument.
If you have enough RAM (>= 32GB), you can use ``--ext bin`` argument to pack all DIV2K images in one binary file.
You can train EDSR and MDSR by yourself. All scripts are provided in ``src/demo.sh``. Note that EDSR (x3, x4) requires a pre-trained EDSR (x2). You can ignore this constraint by removing the ```--pre_train ``` argument.
```bash
cd src # You are now in */EDSR-PyTorch/src
sh demo.sh
```
**Update log**
* Jan 04, 2018
* Many parts are re-written. You cannot use previous scripts and models directly.
* Pre-trained MDSR is temporarily disabled.
* Training details are included.
* Jan 09, 2018
* Missing files are included (```src/data/MyImage.py```).
* Some links are fixed.
* Jan 16, 2018
* A memory-efficient forward function is implemented.
* Add the --chop_forward argument to your script to enable it.
* Basically, this function first splits a large image into small patches, which are merged back after super-resolution. I checked this function with 12GB of memory and a 4000 x 2000 input image at scale 4 (so the output is 16000 x 8000).
* Feb 21, 2018
* Fixed the problem when loading pre-trained multi-gpu model.
* Added pre-trained scale 2 baseline model.
* This code now saves only the best-performing model by default. For MDSR, 'the best' can be ambiguous; use the --save_models argument to save all the intermediate models.
* PyTorch 0.3.1 changed its DataLoader implementation, so I also changed my MSDataLoader implementation accordingly. You can find it on the feature/dataloader branch.
* Feb 23, 2018
* Now PyTorch 0.3.1 is the default. Use the legacy/0.3.0 branch if you use the old version.
* With the new ``src/data/DIV2K.py`` code, one can easily create a new data class for super-resolution.
* New binary data pack. (Please remove the ``DIV2K_decoded`` folder from your dataset if you have one.)
* With ``--ext bin``, this code will automatically generate and save the binary data pack that corresponds to the previous ``DIV2K_decoded``. (This requires a huge amount of RAM (~45GB; swap can be used), so please be careful.)
* If you cannot make the binary pack, just use the default setting (``--ext img``).
* Fixed a bug where the PSNR in the log and the PSNR calculated from the saved images did not match.
* Now saved images have better quality! (PSNR is ~0.1 dB higher than with the original code.)
* Added performance comparison between Torch7 model and PyTorch models.
* Mar 5, 2018
* All baseline models are uploaded.
* Now supports half-precision at test time. Use ``--precision half`` to enable it. This does not degrade the output images.
* Mar 11, 2018
* Fixed some typos in the code and script.
* Now --ext img is the default setting. Although we recommend --ext bin for training, please use --ext img with --test_only.
* The skip_batch operation is implemented. Use the --skip_threshold argument to skip batches you want to ignore. Although this function is not exactly the same as the Torch7 version's, it works as you would expect.
* Mar 20, 2018
* Use ``--ext sep_reset`` to pre-decode large png files. Those decoded files will be saved to the same directory with DIV2K png files. After the first run, you can use ``--ext sep`` to save time.
* Now supports various benchmark datasets. For example, try ``--data_test Set5`` to test your model on the Set5 images.
* Changed the behavior of skip_batch.
* Mar 29, 2018
* We now provide all models from our paper.
* We also provide the ``MDSR_baseline_jpeg`` model that suppresses JPEG artifacts in the original low-resolution image. Please use it if you have any trouble.
* The ``MyImage`` dataset is changed to the ``Demo`` dataset. It also works more efficiently than before.
* Some codes and script are re-written.
* Apr 9, 2018
* VGG and adversarial losses are implemented based on [SRGAN](http://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf). [WGAN](https://arxiv.org/abs/1701.07875) and [gradient penalty](https://arxiv.org/abs/1704.00028) are also implemented, but they have not been tested yet.
* Many codes are refactored. If there exists a bug, please report it.
* [D-DBPN](https://arxiv.org/abs/1803.02735) is implemented. Default setting is D-DBPN-L.
* Apr 26, 2018
* Compatible with PyTorch 0.4.0
* Please use the legacy/0.3.1 branch if you are using the old version of PyTorch.
* Minor bug fixes
================================================
FILE: src/model/__init__.py
================================================
import os
from importlib import import_module
import torch
import torch.nn as nn
from torch.autograd import Variable
class Model(nn.Module):
def __init__(self, args, ckp):
super(Model, self).__init__()
print('Making model...')
self.scale = args.scale
self.idx_scale = 0
self.self_ensemble = args.self_ensemble
self.chop = args.chop
self.precision = args.precision
self.cpu = args.cpu
self.device = torch.device('cpu' if args.cpu else 'cuda')
self.n_GPUs = args.n_GPUs
self.save_models = args.save_models
module = import_module('model.' + args.model.lower())
self.model = module.make_model(args).to(self.device)
if args.precision == 'half': self.model.half()
if not args.cpu and args.n_GPUs > 1:
self.model = nn.DataParallel(self.model, range(args.n_GPUs))
self.load(
ckp.dir,
pre_train=args.pre_train,
resume=args.resume,
cpu=args.cpu
)
print(self.model, file=ckp.log_file)
def forward(self, x, idx_scale):
self.idx_scale = idx_scale
target = self.get_model()
if hasattr(target, 'set_scale'):
target.set_scale(idx_scale)
if self.self_ensemble and not self.training:
if self.chop:
forward_function = self.forward_chop
else:
forward_function = self.model.forward
return self.forward_x8(x, forward_function)
elif self.chop and not self.training:
return self.forward_chop(x)
else:
return self.model(x)
def get_model(self):
if self.n_GPUs == 1:
return self.model
else:
return self.model.module
def state_dict(self, **kwargs):
target = self.get_model()
return target.state_dict(**kwargs)
def save(self, apath, epoch, is_best=False):
target = self.get_model()
torch.save(
target.state_dict(),
os.path.join(apath, 'model_latest.pt')
)
if is_best:
torch.save(
target.state_dict(),
os.path.join(apath, 'model_best.pt')
)
if self.save_models:
torch.save(
target.state_dict(),
os.path.join(apath, 'model_{}.pt'.format(epoch))
)
def load(self, apath, pre_train='.', resume=-1, cpu=False):
if cpu:
kwargs = {'map_location': lambda storage, loc: storage}
else:
kwargs = {}
if resume == -1:
self.get_model().load_state_dict(
torch.load(
os.path.join(apath, 'model_latest.pt'),
**kwargs
),
strict=False
)
elif resume == 0:
if pre_train != '.':
print('Loading model from {}'.format(pre_train))
self.get_model().load_state_dict(
torch.load(pre_train, **kwargs),
strict=False
)
else:
self.get_model().load_state_dict(
torch.load(
os.path.join(apath, 'model', 'model_{}.pt'.format(resume)),
**kwargs
),
strict=False
)
def forward_chop(self, x, shave=10, min_size=120000):
scale = self.scale[self.idx_scale]
n_GPUs = min(self.n_GPUs, 4)
b, c, h, w = x.size()
h_half, w_half = h // 2, w // 2
h_size, w_size = h_half + shave, w_half + shave
h_size += 4 - h_size % 4
w_size += 8 - w_size % 8
lr_list = [
x[:, :, 0:h_size, 0:w_size],
x[:, :, 0:h_size, (w - w_size):w],
x[:, :, (h - h_size):h, 0:w_size],
x[:, :, (h - h_size):h, (w - w_size):w]]
if w_size * h_size < min_size:
sr_list = []
for i in range(0, 4, n_GPUs):
lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
sr_batch = self.model(lr_batch)
sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
else:
sr_list = [
self.forward_chop(patch, shave=shave, min_size=min_size) \
for patch in lr_list
]
h, w = scale * h, scale * w
h_half, w_half = scale * h_half, scale * w_half
h_size, w_size = scale * h_size, scale * w_size
shave *= scale
output = x.new(b, c, h, w)
output[:, :, 0:h_half, 0:w_half] \
= sr_list[0][:, :, 0:h_half, 0:w_half]
output[:, :, 0:h_half, w_half:w] \
= sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
output[:, :, h_half:h, 0:w_half] \
= sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
output[:, :, h_half:h, w_half:w] \
= sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
return output
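The quadrant bookkeeping in forward_chop is easiest to verify at scale 1, where stitching the four overlapping patches back together must reproduce the input exactly. A 2-D NumPy sketch (single-channel, identity "model"; chop_merge is our name, not the project API):

```python
import numpy as np

def chop_merge(x, model, shave=2):
    """Scale-1 sketch of Model.forward_chop: split into 4 overlapping
    quadrants, run `model` on each, and stitch the centers back together."""
    h, w = x.shape
    h_half, w_half = h // 2, w // 2
    h_size, w_size = h_half + shave, w_half + shave
    patches = [
        x[0:h_size, 0:w_size],      x[0:h_size, w - w_size:w],
        x[h - h_size:h, 0:w_size],  x[h - h_size:h, w - w_size:w],
    ]
    sr = [model(p) for p in patches]
    out = np.empty_like(x)
    out[0:h_half, 0:w_half] = sr[0][0:h_half, 0:w_half]
    out[0:h_half, w_half:w] = sr[1][0:h_half, w_size - w + w_half:w_size]
    out[h_half:h, 0:w_half] = sr[2][h_size - h + h_half:h_size, 0:w_half]
    out[h_half:h, w_half:w] = sr[3][h_size - h + h_half:h_size, w_size - w + w_half:w_size]
    return out

img = np.arange(12 * 16, dtype=float).reshape(12, 16)
merged = chop_merge(img, model=lambda p: p)  # identity "model" at scale 1
```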
def forward_x8(self, x, forward_function):
def _transform(v, op):
if self.precision != 'single': v = v.float()
v2np = v.data.cpu().numpy()
if op == 'v':
tfnp = v2np[:, :, :, ::-1].copy()
elif op == 'h':
tfnp = v2np[:, :, ::-1, :].copy()
elif op == 't':
tfnp = v2np.transpose((0, 1, 3, 2)).copy()
ret = torch.Tensor(tfnp).to(self.device)
if self.precision == 'half': ret = ret.half()
return ret
lr_list = [x]
for tf in 'v', 'h', 't':
lr_list.extend([_transform(t, tf) for t in lr_list])
sr_list = [forward_function(aug) for aug in lr_list]
for i in range(len(sr_list)):
if i > 3:
sr_list[i] = _transform(sr_list[i], 't')
if i % 4 > 1:
sr_list[i] = _transform(sr_list[i], 'h')
if (i % 4) % 2 == 1:
sr_list[i] = _transform(sr_list[i], 'v')
output_cat = torch.cat(sr_list, dim=0)
output = output_cat.mean(dim=0, keepdim=True)
return output
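The forward_x8 geometric self-ensemble can be sketched in NumPy: build 8 flip/transpose variants of the input, run the network on each, undo the transforms, and average. With an identity network the average must equal the input (a sketch with our helper names, assuming a square single-channel image):

```python
import numpy as np

def transform(v, op):
    # 'v': flip along the last axis, 'h': flip along the first, 't': transpose
    if op == 'v': return v[:, ::-1].copy()
    if op == 'h': return v[::-1, :].copy()
    return v.T.copy()

def forward_x8(x, forward):
    """Geometric self-ensemble: average predictions over 8 flip/transpose variants."""
    lr = [x]
    for tf in ('v', 'h', 't'):
        lr.extend([transform(t, tf) for t in lr])  # doubles the list -> 8 variants
    sr = [forward(a) for a in lr]
    for i in range(len(sr)):                       # undo the transforms before averaging
        if i > 3:        sr[i] = transform(sr[i], 't')
        if i % 4 > 1:    sr[i] = transform(sr[i], 'h')
        if (i % 4) % 2:  sr[i] = transform(sr[i], 'v')
    return np.mean(sr, axis=0)

img = np.random.rand(8, 8)
out = forward_x8(img, forward=lambda a: a)  # identity "network"
```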
================================================
FILE: src/model/attention.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import common
class NonLocalSparseAttention(nn.Module):
def __init__( self, n_hashes=4, channels=64, k_size=3, reduction=4, chunk_size=144, conv=common.default_conv, res_scale=1):
super(NonLocalSparseAttention,self).__init__()
self.chunk_size = chunk_size
self.n_hashes = n_hashes
self.reduction = reduction
self.res_scale = res_scale
self.conv_match = common.BasicBlock(conv, channels, channels//reduction, k_size, bn=False, act=None)
self.conv_assembly = common.BasicBlock(conv, channels, channels, 1, bn=False, act=None)
def LSH(self, hash_buckets, x):
#x: [N,H*W,C]
N = x.shape[0]
device = x.device
#generate random rotation matrix
rotations_shape = (1, x.shape[-1], self.n_hashes, hash_buckets//2) #[1,C,n_hashes,hash_buckets//2]
random_rotations = torch.randn(rotations_shape, dtype=x.dtype, device=device).expand(N, -1, -1, -1) #[N, C, n_hashes, hash_buckets//2]
#locality sensitive hashing
rotated_vecs = torch.einsum('btf,bfhi->bhti', x, random_rotations) #[N, n_hashes, H*W, hash_buckets//2]
rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1) #[N, n_hashes, H*W, hash_buckets]
#get hash codes
hash_codes = torch.argmax(rotated_vecs, dim=-1) #[N,n_hashes,H*W]
#add offsets to avoid hash codes overlapping between hash rounds
offsets = torch.arange(self.n_hashes, device=device)
offsets = torch.reshape(offsets * hash_buckets, (1, -1, 1))
hash_codes = torch.reshape(hash_codes + offsets, (N, -1,)) #[N,n_hashes*H*W]
return hash_codes
def add_adjacent_buckets(self, x):
x_extra_back = torch.cat([x[:,:,-1:, ...], x[:,:,:-1, ...]], dim=2)
x_extra_forward = torch.cat([x[:,:,1:, ...], x[:,:,:1,...]], dim=2)
return torch.cat([x, x_extra_back,x_extra_forward], dim=3)
def forward(self, input):
N,_,H,W = input.shape
x_embed = self.conv_match(input).view(N,-1,H*W).contiguous().permute(0,2,1)
y_embed = self.conv_assembly(input).view(N,-1,H*W).contiguous().permute(0,2,1)
L,C = x_embed.shape[-2:]
#number of hash buckets/hash bits
hash_buckets = min(L//self.chunk_size + (L//self.chunk_size)%2, 128)
#get assigned hash codes/bucket number
hash_codes = self.LSH(hash_buckets, x_embed) #[N,n_hashes*H*W]
hash_codes = hash_codes.detach()
#group elements with same hash code by sorting
_, indices = hash_codes.sort(dim=-1) #[N,n_hashes*H*W]
_, undo_sort = indices.sort(dim=-1) #undo_sort to recover original order
mod_indices = (indices % L) #now range from (0->H*W)
x_embed_sorted = common.batched_index_select(x_embed, mod_indices) #[N,n_hashes*H*W,C]
y_embed_sorted = common.batched_index_select(y_embed, mod_indices) #[N,n_hashes*H*W,C]
#pad the embedding if it cannot be divided by chunk_size
padding = self.chunk_size - L%self.chunk_size if L%self.chunk_size!=0 else 0
x_att_buckets = torch.reshape(x_embed_sorted, (N, self.n_hashes,-1, C)) #[N, n_hashes, H*W,C]
y_att_buckets = torch.reshape(y_embed_sorted, (N, self.n_hashes,-1, C*self.reduction))
if padding:
pad_x = x_att_buckets[:,:,-padding:,:].clone()
pad_y = y_att_buckets[:,:,-padding:,:].clone()
x_att_buckets = torch.cat([x_att_buckets,pad_x],dim=2)
y_att_buckets = torch.cat([y_att_buckets,pad_y],dim=2)
x_att_buckets = torch.reshape(x_att_buckets,(N,self.n_hashes,-1,self.chunk_size,C)) #[N, n_hashes, num_chunks, chunk_size, C]
y_att_buckets = torch.reshape(y_att_buckets,(N,self.n_hashes,-1,self.chunk_size, C*self.reduction))
x_match = F.normalize(x_att_buckets, p=2, dim=-1,eps=5e-5)
#allow attend to adjacent buckets
x_match = self.add_adjacent_buckets(x_match)
y_att_buckets = self.add_adjacent_buckets(y_att_buckets)
#unnormalized attention score
raw_score = torch.einsum('bhkie,bhkje->bhkij', x_att_buckets, x_match) #[N, n_hashes, num_chunks, chunk_size, chunk_size*3]
#softmax
bucket_score = torch.logsumexp(raw_score, dim=-1, keepdim=True)
score = torch.exp(raw_score - bucket_score) #(after softmax)
bucket_score = torch.reshape(bucket_score,[N,self.n_hashes,-1])
#attention
ret = torch.einsum('bukij,bukje->bukie', score, y_att_buckets) #[N, n_hashes, num_chunks, chunk_size, C]
ret = torch.reshape(ret,(N,self.n_hashes,-1,C*self.reduction))
#if padded, then remove extra elements
if padding:
ret = ret[:,:,:-padding,:].clone()
bucket_score = bucket_score[:,:,:-padding].clone()
#recover the original order
ret = torch.reshape(ret, (N, -1, C*self.reduction)) #[N, n_hashes*H*W,C]
bucket_score = torch.reshape(bucket_score, (N, -1,)) #[N,n_hashes*H*W]
ret = common.batched_index_select(ret, undo_sort)#[N, n_hashes*H*W,C]
bucket_score = bucket_score.gather(1, undo_sort)#[N,n_hashes*H*W]
#weighted sum multi-round attention
ret = torch.reshape(ret, (N, self.n_hashes, L, C*self.reduction)) #[N, n_hashes, L, C*reduction]
bucket_score = torch.reshape(bucket_score, (N, self.n_hashes, L, 1))
probs = nn.functional.softmax(bucket_score,dim=1)
ret = torch.sum(ret * probs, dim=1)
ret = ret.permute(0,2,1).view(N,-1,H,W).contiguous()*self.res_scale+input
return ret
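The spherical LSH above reduces to: project each feature onto random directions, concatenate the projections with their negations, and take the argmax as the bucket id, so vectors pointing the same way share a bucket while opposite vectors do not. A NumPy sketch under those assumptions (toy sizes, our function name):

```python
import numpy as np

def spherical_lsh(x, n_buckets, rng):
    """Spherical LSH as in NonLocalSparseAttention.LSH: project rows of x onto
    n_buckets//2 random directions, concatenate [p, -p], take the argmax."""
    r = rng.standard_normal((x.shape[1], n_buckets // 2))  # random rotation matrix
    p = x @ r
    return np.argmax(np.concatenate([p, -p], axis=1), axis=1)

rng = np.random.default_rng(0)
a = rng.standard_normal((1, 16))
x = np.concatenate([a, 2.0 * a, -a])  # same direction, a scaled copy, the opposite direction
codes = spherical_lsh(x, n_buckets=8, rng=rng)
```

Scaling a vector by a positive constant never changes the argmax, so `a` and `2*a` always share a bucket; negating it flips every projection and moves it elsewhere.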
class NonLocalAttention(nn.Module):
def __init__(self, channel=128, reduction=2, ksize=1, scale=3, stride=1, softmax_scale=10, average=True, res_scale=1,conv=common.default_conv):
super(NonLocalAttention, self).__init__()
self.res_scale = res_scale
self.conv_match1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
self.conv_match2 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act = nn.PReLU())
self.conv_assembly = common.BasicBlock(conv, channel, channel, 1,bn=False, act=nn.PReLU())
def forward(self, input):
x_embed_1 = self.conv_match1(input)
x_embed_2 = self.conv_match2(input)
x_assembly = self.conv_assembly(input)
N,C,H,W = x_embed_1.shape
x_embed_1 = x_embed_1.permute(0,2,3,1).view((N,H*W,C))
x_embed_2 = x_embed_2.view(N,C,H*W)
score = torch.matmul(x_embed_1, x_embed_2)
score = F.softmax(score, dim=2)
x_assembly = x_assembly.view(N,-1,H*W).permute(0,2,1)
x_final = torch.matmul(score, x_assembly)
return x_final.permute(0,2,1).view(N,-1,H,W)+self.res_scale*input
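For contrast with NLSA's bucketed attention, the dense NonLocalAttention above computes a full L x L attention map, so cost grows quadratically with the number of positions. A minimal NumPy sketch of its score/assembly step (unbatched, toy sizes, hypothetical embeddings):

```python
import numpy as np

def softmax(z, axis):
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(1)
L, C = 6, 4
x_embed_1 = rng.standard_normal((L, C))         # per-pixel matching embedding 1
x_embed_2 = rng.standard_normal((C, L))         # per-pixel matching embedding 2
x_assembly = rng.standard_normal((L, C))        # features to be assembled

score = softmax(x_embed_1 @ x_embed_2, axis=1)  # [L, L] full attention map: O(L^2)
out = score @ x_assembly                        # every output pixel mixes all L inputs
```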
================================================
FILE: src/model/common.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def batched_index_select(values, indices):
last_dim = values.shape[-1]
return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim))
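batched_index_select gathers rows of a [N, L, C] tensor along the length dimension using a separate index list per batch element. An equivalent NumPy sketch (our function name):

```python
import numpy as np

def batched_index_select_np(values, indices):
    """Pick rows of `values` ([N, L, C]) given per-batch row indices ([N, K])."""
    return np.stack([v[idx] for v, idx in zip(values, indices)])

values = np.arange(2 * 3 * 2).reshape(2, 3, 2)  # [N=2, L=3, C=2]
indices = np.array([[2, 0], [1, 1]])            # rows to gather per batch element
picked = batched_index_select_np(values, indices)
```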
def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size//2),stride=stride, bias=bias)
class MeanShift(nn.Conv2d):
def __init__(
self, rgb_range,
rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
super(MeanShift, self).__init__(3, 3, kernel_size=1)
std = torch.Tensor(rgb_std)
self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
for p in self.parameters():
p.requires_grad = False
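MeanShift is a frozen 1x1 convolution whose diagonal weights and bias implement per-channel normalization. A NumPy sketch of the equivalent arithmetic (sign = -1 subtracts the dataset mean, +1 adds it back):

```python
import numpy as np

rgb_mean = np.array([0.4488, 0.4371, 0.4040])  # DIV2K channel means, as in MeanShift
rgb_std = np.array([1.0, 1.0, 1.0])

def mean_shift(x, rgb_range=1, sign=-1):
    """Equivalent of the frozen 1x1 conv: weight = diag(1/std),
    bias = sign * rgb_range * mean / std, applied to a [C, H, W] image."""
    return x / rgb_std[:, None, None] + sign * rgb_range * (rgb_mean / rgb_std)[:, None, None]

x = np.full((3, 2, 2), 0.5)   # constant [C, H, W] image
y = mean_shift(x)             # mean-subtracted (network input)
z = mean_shift(y, sign=1)     # adding the mean back restores the original range
```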
class BasicBlock(nn.Sequential):
def __init__(
self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
bn=False, act=nn.PReLU()):
m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
if bn:
m.append(nn.BatchNorm2d(out_channels))
if act is not None:
m.append(act)
super(BasicBlock, self).__init__(*m)
class ResBlock(nn.Module):
def __init__(
self, conv, n_feats, kernel_size,
bias=True, bn=False, act=nn.PReLU(), res_scale=1):
super(ResBlock, self).__init__()
m = []
for i in range(2):
m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
if bn:
m.append(nn.BatchNorm2d(n_feats))
if i == 0:
m.append(act)
self.body = nn.Sequential(*m)
self.res_scale = res_scale
def forward(self, x):
res = self.body(x).mul(self.res_scale)
res += x
return res
class Upsampler(nn.Sequential):
def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
m = []
if (scale & (scale - 1)) == 0: # Is scale = 2^n?
for _ in range(int(math.log(scale, 2))):
m.append(conv(n_feats, 4 * n_feats, 3, bias=bias))
m.append(nn.PixelShuffle(2))
if bn:
m.append(nn.BatchNorm2d(n_feats))
if act == 'relu':
m.append(nn.ReLU(True))
elif act == 'prelu':
m.append(nn.PReLU(n_feats))
elif scale == 3:
m.append(conv(n_feats, 9 * n_feats, 3, bias=bias))
m.append(nn.PixelShuffle(3))
if bn:
m.append(nn.BatchNorm2d(n_feats))
if act == 'relu':
m.append(nn.ReLU(True))
elif act == 'prelu':
m.append(nn.PReLU(n_feats))
else:
raise NotImplementedError
super(Upsampler, self).__init__(*m)
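The Upsampler relies on nn.PixelShuffle, which trades r^2 channels for an r x r spatial upscale: the conv produces 4*n_feats channels and PixelShuffle(2) rearranges them into a 2x larger feature map. A NumPy sketch of that rearrangement (our helper, mirroring the PixelShuffle semantics):

```python
import numpy as np

def pixel_shuffle(x, r):
    """Rearrange [C*r^2, H, W] -> [C, H*r, W*r], as nn.PixelShuffle does."""
    c, h, w = x.shape
    x = x.reshape(c // (r * r), r, r, h, w)  # split channels into (C, r, r)
    return x.transpose(0, 3, 1, 4, 2).reshape(c // (r * r), h * r, w * r)

x = np.arange(4 * 2 * 2).reshape(4, 2, 2)  # 4 channels -> 1 channel at 2x resolution
y = pixel_shuffle(x, 2)
```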
================================================
FILE: src/model/ddbpn.py
================================================
# Deep Back-Projection Networks For Super-Resolution
# https://arxiv.org/abs/1803.02735
from model import common
import torch
import torch.nn as nn
def make_model(args, parent=False):
return DDBPN(args)
def projection_conv(in_channels, out_channels, scale, up=True):
kernel_size, stride, padding = {
2: (6, 2, 2),
4: (8, 4, 2),
8: (12, 8, 2)
}[scale]
if up:
conv_f = nn.ConvTranspose2d
else:
conv_f = nn.Conv2d
return conv_f(
in_channels, out_channels, kernel_size,
stride=stride, padding=padding
)
class DenseProjection(nn.Module):
def __init__(self, in_channels, nr, scale, up=True, bottleneck=True):
super(DenseProjection, self).__init__()
if bottleneck:
self.bottleneck = nn.Sequential(*[
nn.Conv2d(in_channels, nr, 1),
nn.PReLU(nr)
])
inter_channels = nr
else:
self.bottleneck = None
inter_channels = in_channels
self.conv_1 = nn.Sequential(*[
projection_conv(inter_channels, nr, scale, up),
nn.PReLU(nr)
])
self.conv_2 = nn.Sequential(*[
projection_conv(nr, inter_channels, scale, not up),
nn.PReLU(inter_channels)
])
self.conv_3 = nn.Sequential(*[
projection_conv(inter_channels, nr, scale, up),
nn.PReLU(nr)
])
def forward(self, x):
if self.bottleneck is not None:
x = self.bottleneck(x)
a_0 = self.conv_1(x)
b_0 = self.conv_2(a_0)
e = b_0.sub(x)
a_1 = self.conv_3(e)
out = a_0.add(a_1)
return out
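DenseProjection's up branch follows the back-projection idea: upsample, project back down, and add an upsampled correction of the residual. A 1-D NumPy sketch with nearest-neighbor up and average-pool down standing in for the learned convolutions (with these stand-ins the residual happens to vanish):

```python
import numpy as np

def up(x, s=2):   return np.repeat(x, s)                # stand-in for the learned up-conv
def down(x, s=2): return x.reshape(-1, s).mean(axis=1)  # stand-in for the learned down-conv

def up_projection(x):
    a_0 = up(x)      # initial upsampling
    b_0 = down(a_0)  # project back to the low-resolution space
    e = b_0 - x      # back-projection residual
    a_1 = up(e)      # upsample the residual as a correction
    return a_0 + a_1

x = np.array([1.0, 2.0, 3.0, 4.0])
y = up_projection(x)  # here down(up(x)) == x, so the correction is zero
```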
class DDBPN(nn.Module):
def __init__(self, args):
super(DDBPN, self).__init__()
scale = args.scale[0]
n0 = 128
nr = 32
self.depth = 6
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
initial = [
nn.Conv2d(args.n_colors, n0, 3, padding=1),
nn.PReLU(n0),
nn.Conv2d(n0, nr, 1),
nn.PReLU(nr)
]
self.initial = nn.Sequential(*initial)
self.upmodules = nn.ModuleList()
self.downmodules = nn.ModuleList()
channels = nr
for i in range(self.depth):
self.upmodules.append(
DenseProjection(channels, nr, scale, True, i > 1)
)
if i != 0:
channels += nr
channels = nr
for i in range(self.depth - 1):
self.downmodules.append(
DenseProjection(channels, nr, scale, False, i != 0)
)
channels += nr
reconstruction = [
nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1)
]
self.reconstruction = nn.Sequential(*reconstruction)
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
def forward(self, x):
x = self.sub_mean(x)
x = self.initial(x)
h_list = []
l_list = []
for i in range(self.depth - 1):
if i == 0:
l = x
else:
l = torch.cat(l_list, dim=1)
h_list.append(self.upmodules[i](l))
l_list.append(self.downmodules[i](torch.cat(h_list, dim=1)))
h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1)))
out = self.reconstruction(torch.cat(h_list, dim=1))
out = self.add_mean(out)
return out
================================================
FILE: src/model/edsr.py
================================================
from model import common
from model import attention
import torch.nn as nn
def make_model(args, parent=False):
if args.dilation:
from model import dilated
return EDSR(args, dilated.dilated_conv)
else:
return EDSR(args)
class EDSR(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(EDSR, self).__init__()
n_resblock = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
scale = args.scale[0]
act = nn.ReLU(True)
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
#self.msa = attention.PyramidAttention(channel=256, reduction=8,res_scale=args.res_scale);
# define head module
m_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
m_body = [
common.ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
) for _ in range(n_resblock//2)
]
#m_body.append(self.msa)
for _ in range(n_resblock//2):
m_body.append( common.ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
))
m_body.append(conv(n_feats, n_feats, kernel_size))
# define tail module
m_tail = [
common.Upsampler(conv, scale, n_feats, act=False),
nn.Conv2d(
n_feats, args.n_colors, kernel_size,
padding=(kernel_size//2)
)
]
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.Sequential(*m_head)
self.body = nn.Sequential(*m_body)
self.tail = nn.Sequential(*m_tail)
def forward(self, x):
x = self.sub_mean(x)
x = self.head(x)
res = self.body(x)
res += x
x = self.tail(res)
x = self.add_mean(x)
return x
def load_state_dict(self, state_dict, strict=True):
own_state = self.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
param = param.data
try:
own_state[name].copy_(param)
except Exception:
if name.find('tail') == -1:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
if name.find('tail') == -1:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
================================================
FILE: src/model/mdsr.py
================================================
from model import common
import torch.nn as nn
def make_model(args, parent=False):
return MDSR(args)
class MDSR(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(MDSR, self).__init__()
n_resblocks = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
self.scale_idx = 0
act = nn.ReLU(True)
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
m_head = [conv(args.n_colors, n_feats, kernel_size)]
self.pre_process = nn.ModuleList([
nn.Sequential(
common.ResBlock(conv, n_feats, 5, act=act),
common.ResBlock(conv, n_feats, 5, act=act)
) for _ in args.scale
])
m_body = [
common.ResBlock(
conv, n_feats, kernel_size, act=act
) for _ in range(n_resblocks)
]
m_body.append(conv(n_feats, n_feats, kernel_size))
self.upsample = nn.ModuleList([
common.Upsampler(
conv, s, n_feats, act=False
) for s in args.scale
])
m_tail = [conv(n_feats, args.n_colors, kernel_size)]
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.Sequential(*m_head)
self.body = nn.Sequential(*m_body)
self.tail = nn.Sequential(*m_tail)
def forward(self, x):
x = self.sub_mean(x)
x = self.head(x)
x = self.pre_process[self.scale_idx](x)
res = self.body(x)
res += x
x = self.upsample[self.scale_idx](res)
x = self.tail(x)
x = self.add_mean(x)
return x
def set_scale(self, scale_idx):
self.scale_idx = scale_idx
================================================
FILE: src/model/mssr.py
================================================
from model import common
import torch.nn as nn
import torch
from model.attention import ContextualAttention,NonLocalAttention
def make_model(args, parent=False):
return MSSR(args)
class MultisourceProjection(nn.Module):
def __init__(self, in_channel,kernel_size = 3, conv=common.default_conv):
super(MultisourceProjection, self).__init__()
self.up_attention = ContextualAttention(scale=2)
self.down_attention = NonLocalAttention()
self.upsample = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])
self.encoder = common.ResBlock(conv, in_channel, kernel_size, act=nn.PReLU(), res_scale=1)
def forward(self,x):
down_map = self.upsample(self.down_attention(x))
up_map = self.up_attention(x)
err = self.encoder(up_map-down_map)
final_map = down_map + err
return final_map
class RecurrentProjection(nn.Module):
def __init__(self, in_channel,kernel_size = 3, conv=common.default_conv):
super(RecurrentProjection, self).__init__()
self.multi_source_projection_1 = MultisourceProjection(in_channel,kernel_size=kernel_size,conv=conv)
self.multi_source_projection_2 = MultisourceProjection(in_channel,kernel_size=kernel_size,conv=conv)
self.down_sample_1 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])
#self.down_sample_2 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])
self.down_sample_3 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()])
self.down_sample_4 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()])
self.error_encode_1 = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])
self.error_encode_2 = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()])
self.post_conv = common.BasicBlock(conv,in_channel,in_channel,kernel_size,stride=1,bias=True,act=nn.PReLU())
def forward(self, x):
x_up = self.multi_source_projection_1(x)
x_down = self.down_sample_1(x_up)
error_up = self.error_encode_1(x-x_down)
h_estimate_1 = x_up + error_up
x_up_2 = self.multi_source_projection_2(h_estimate_1)
x_down_2 = self.down_sample_3(x_up_2)
error_up_2 = self.error_encode_2(x-x_down_2)
h_estimate_2 = x_up_2 + error_up_2
x_final = self.post_conv(self.down_sample_4(h_estimate_2))
return x_final, h_estimate_2
class MSSR(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(MSSR, self).__init__()
#n_convblock = args.n_convblocks
n_feats = args.n_feats
self.depth = args.depth
kernel_size = 3
scale = args.scale[0]
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
# define head module
m_head = [common.BasicBlock(conv, args.n_colors, n_feats, kernel_size,stride=1,bias=True,bn=False,act=nn.PReLU()),
common.BasicBlock(conv,n_feats, n_feats, kernel_size,stride=1,bias=True,bn=False,act=nn.PReLU())]
# define multiple reconstruction module
self.body = RecurrentProjection(n_feats)
# define tail module
m_tail = [
nn.Conv2d(
n_feats*self.depth, args.n_colors, kernel_size,
padding=(kernel_size//2)
)
]
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.Sequential(*m_head)
self.tail = nn.Sequential(*m_tail)
def forward(self,input):
x = self.sub_mean(input)
x = self.head(x)
bag = []
for i in range(self.depth):
x, h_estimate = self.body(x)
bag.append(h_estimate)
h_feature = torch.cat(bag,dim=1)
h_final = self.tail(h_feature)
return self.add_mean(h_final)
================================================
FILE: src/model/nlsn.py
================================================
from model import common
from model import attention
import torch.nn as nn
def make_model(args, parent=False):
if args.dilation:
from model import dilated
return NLSN(args, dilated.dilated_conv)
else:
return NLSN(args)
class NLSN(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(NLSN, self).__init__()
n_resblock = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
scale = args.scale[0]
act = nn.ReLU(True)
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
m_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
m_body = [attention.NonLocalSparseAttention(
channels=n_feats, chunk_size=args.chunk_size, n_hashes=args.n_hashes, reduction=4, res_scale=args.res_scale)]
for i in range(n_resblock):
m_body.append( common.ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
))
if (i+1)%8==0:
m_body.append(attention.NonLocalSparseAttention(
channels=n_feats, chunk_size=args.chunk_size, n_hashes=args.n_hashes, reduction=4, res_scale=args.res_scale))
m_body.append(conv(n_feats, n_feats, kernel_size))
# define tail module
m_tail = [
common.Upsampler(conv, scale, n_feats, act=False),
nn.Conv2d(
n_feats, args.n_colors, kernel_size,
padding=(kernel_size//2)
)
]
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.Sequential(*m_head)
self.body = nn.Sequential(*m_body)
self.tail = nn.Sequential(*m_tail)
def forward(self, x):
x = self.sub_mean(x)
x = self.head(x)
res = self.body(x)
res += x
x = self.tail(res)
x = self.add_mean(x)
return x
def load_state_dict(self, state_dict, strict=True):
own_state = self.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
param = param.data
try:
own_state[name].copy_(param)
except Exception:
if name.find('tail') == -1:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
if name.find('tail') == -1:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
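A quick pure-Python sketch of how the body above is assembled: one `NonLocalSparseAttention` module leads, another follows every 8th `ResBlock`, and a closing conv ends the chain. The function and names below are illustrative stand-ins, not the real modules.

```python
# Sketch of NLSN's body layout: NLSA up front, then ResBlocks with an extra
# NLSA inserted after every 8th block, and a final conv. Labels are
# illustrative placeholders for the actual nn.Module instances.
def nlsn_body_layout(n_resblocks):
    layout = ["NLSA"]
    for i in range(n_resblocks):
        layout.append("ResBlock")
        if (i + 1) % 8 == 0:
            layout.append("NLSA")
    layout.append("Conv")
    return layout

# With the default of 32 residual blocks, the body holds 5 NLSA modules
# (one leading, plus one after blocks 8, 16, 24, and 32).
layout = nlsn_body_layout(32)
```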
================================================
FILE: src/model/rcan.py
================================================
## ECCV-2018-Image Super-Resolution Using Very Deep Residual Channel Attention Networks
## https://arxiv.org/abs/1807.02758
from model import common
import torch.nn as nn
import torch
def make_model(args, parent=False):
return RCAN(args)
## Channel Attention (CA) Layer
class CALayer(nn.Module):
def __init__(self, channel, reduction=16):
super(CALayer, self).__init__()
# global average pooling: feature --> point
self.avg_pool = nn.AdaptiveAvgPool2d(1)
# feature channel downscale and upscale --> channel weight
#self.a = torch.nn.Parameter(torch.Tensor([0]))
#self.a.requires_grad=True
self.conv_du = nn.Sequential(
nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
nn.Sigmoid()
)
def forward(self, x):
y = self.avg_pool(x)
y = self.conv_du(y)
return x * y
## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
def __init__(
self, conv, n_feat, kernel_size, reduction,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(RCAB, self).__init__()
modules_body = []
for i in range(2):
modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
if bn: modules_body.append(nn.BatchNorm2d(n_feat))
if i == 0: modules_body.append(act)
modules_body.append(CALayer(n_feat, reduction))
self.body = nn.Sequential(*modules_body)
self.res_scale = res_scale
def forward(self, x):
res = self.body(x)
#res = self.body(x).mul(self.res_scale)
res += x
return res
## Residual Group (RG)
class ResidualGroup(nn.Module):
def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
super(ResidualGroup, self).__init__()
modules_body = []
modules_body = [
RCAB(
conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \
for _ in range(n_resblocks)]
modules_body.append(conv(n_feat, n_feat, kernel_size))
self.body = nn.Sequential(*modules_body)
def forward(self, x):
res = self.body(x)
res += x
return res
## Residual Channel Attention Network (RCAN)
class RCAN(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(RCAN, self).__init__()
self.a = nn.Parameter(torch.Tensor([0]))
self.a.requires_grad=True
n_resgroups = args.n_resgroups
n_resblocks = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
reduction = args.reduction
scale = args.scale[0]
act = nn.ReLU(True)
# RGB mean for DIV2K
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
# define head module
modules_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
modules_body = [
ResidualGroup(
conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \
for _ in range(n_resgroups)]
modules_body.append(conv(n_feats, n_feats, kernel_size))
# define tail module
modules_tail = [
common.Upsampler(conv, scale, n_feats, act=False),
conv(n_feats, args.n_colors, kernel_size)]
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.Sequential(*modules_head)
self.body = nn.Sequential(*modules_body)
self.tail = nn.Sequential(*modules_tail)
def forward(self, x):
x = self.sub_mean(x)
x = self.head(x)
res = self.body(x)
res += x
x = self.tail(res)
x = self.add_mean(x)
return x
def load_state_dict(self, state_dict, strict=False):
own_state = self.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
param = param.data
try:
own_state[name].copy_(param)
except Exception:
                    if name.find('msa') >= 0 or name.find('a') >= 0:
print('Replace pre-trained upsampler to new one...')
else:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
if name.find('msa') == -1:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
if strict:
missing = set(own_state.keys()) - set(state_dict.keys())
if len(missing) > 0:
raise KeyError('missing keys in state_dict: "{}"'.format(missing))
================================================
FILE: src/model/rdn.py
================================================
# Residual Dense Network for Image Super-Resolution
# https://arxiv.org/abs/1802.08797
from model import common
import torch
import torch.nn as nn
def make_model(args, parent=False):
return RDN(args)
class RDB_Conv(nn.Module):
def __init__(self, inChannels, growRate, kSize=3):
super(RDB_Conv, self).__init__()
Cin = inChannels
G = growRate
self.conv = nn.Sequential(*[
nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1),
nn.ReLU()
])
def forward(self, x):
out = self.conv(x)
return torch.cat((x, out), 1)
class RDB(nn.Module):
def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
super(RDB, self).__init__()
G0 = growRate0
G = growRate
C = nConvLayers
convs = []
for c in range(C):
convs.append(RDB_Conv(G0 + c*G, G))
self.convs = nn.Sequential(*convs)
# Local Feature Fusion
self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1)
def forward(self, x):
return self.LFF(self.convs(x)) + x
class RDN(nn.Module):
def __init__(self, args):
super(RDN, self).__init__()
r = args.scale[0]
G0 = args.G0
kSize = args.RDNkSize
# number of RDB blocks, conv layers, out channels
self.D, C, G = {
'A': (20, 6, 32),
'B': (16, 8, 64),
}[args.RDNconfig]
# Shallow feature extraction net
self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)
self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
        # Residual dense blocks and dense feature fusion
self.RDBs = nn.ModuleList()
for i in range(self.D):
self.RDBs.append(
RDB(growRate0 = G0, growRate = G, nConvLayers = C)
)
# Global Feature Fusion
self.GFF = nn.Sequential(*[
nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
])
# Up-sampling net
if r == 2 or r == 3:
self.UPNet = nn.Sequential(*[
nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),
nn.PixelShuffle(r),
nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)
])
elif r == 4:
self.UPNet = nn.Sequential(*[
nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),
nn.PixelShuffle(2),
nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),
nn.PixelShuffle(2),
nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)
])
else:
raise ValueError("scale must be 2 or 3 or 4.")
def forward(self, x):
f__1 = self.SFENet1(x)
x = self.SFENet2(f__1)
RDBs_out = []
for i in range(self.D):
x = self.RDBs[i](x)
RDBs_out.append(x)
x = self.GFF(torch.cat(RDBs_out,1))
x += f__1
return self.UPNet(x)
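Channel bookkeeping inside one RDB, sketched in pure Python: each `RDB_Conv` concatenates its `G` new feature maps onto its input, so conv `c` sees `G0 + c*G` input channels, and the 1x1 LFF fuses `G0 + C*G` channels back down to `G0`. The concrete values below assume `RDNconfig` 'B' (`G0=64, C=8, G=64`).

```python
# Trace the per-conv input widths of a Residual Dense Block: every RDB_Conv
# appends G channels, so conv c receives G0 + c*G; the local feature fusion
# (LFF) 1x1 conv then maps G0 + C*G channels back to G0.
def rdb_channel_trace(G0, G, C):
    in_channels = [G0 + c * G for c in range(C)]  # input width of each conv
    lff_in = G0 + C * G                           # input width of the LFF conv
    return in_channels, lff_in

# RDNconfig 'B' uses G0=64, G=64, C=8.
ins, lff_in = rdb_channel_trace(64, 64, 8)
```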
================================================
FILE: src/model/utils/__init__.py
================================================
================================================
FILE: src/model/utils/tools.py
================================================
import os
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F
def normalize(x):
return x.mul_(2).add_(-1)
def same_padding(images, ksizes, strides, rates):
assert len(images.size()) == 4
batch_size, channel, rows, cols = images.size()
out_rows = (rows + strides[0] - 1) // strides[0]
out_cols = (cols + strides[1] - 1) // strides[1]
effective_k_row = (ksizes[0] - 1) * rates[0] + 1
effective_k_col = (ksizes[1] - 1) * rates[1] + 1
padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
# Pad the input
padding_top = int(padding_rows / 2.)
padding_left = int(padding_cols / 2.)
padding_bottom = padding_rows - padding_top
padding_right = padding_cols - padding_left
paddings = (padding_left, padding_right, padding_top, padding_bottom)
images = torch.nn.ZeroPad2d(paddings)(images)
return images
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
"""
Extract patches from images and put them in the C output dimension.
:param padding:
:param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
:param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
each dimension of images
:param strides: [stride_rows, stride_cols]
:param rates: [dilation_rows, dilation_cols]
:return: A Tensor
"""
assert len(images.size()) == 4
assert padding in ['same', 'valid']
batch_size, channel, height, width = images.size()
if padding == 'same':
images = same_padding(images, ksizes, strides, rates)
elif padding == 'valid':
pass
else:
        raise NotImplementedError('Unsupported padding type: {}. '
                                  'Only "same" or "valid" are supported.'.format(padding))
unfold = torch.nn.Unfold(kernel_size=ksizes,
dilation=rates,
padding=0,
stride=strides)
patches = unfold(images)
return patches # [N, C*k*k, L], L is the total number of such blocks
def reduce_mean(x, axis=None, keepdim=False):
    if axis is None:
axis = range(len(x.shape))
for i in sorted(axis, reverse=True):
x = torch.mean(x, dim=i, keepdim=keepdim)
return x
def reduce_std(x, axis=None, keepdim=False):
    if axis is None:
axis = range(len(x.shape))
for i in sorted(axis, reverse=True):
x = torch.std(x, dim=i, keepdim=keepdim)
return x
def reduce_sum(x, axis=None, keepdim=False):
    if axis is None:
axis = range(len(x.shape))
for i in sorted(axis, reverse=True):
x = torch.sum(x, dim=i, keepdim=keepdim)
return x
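The padding arithmetic in `same_padding()` above can be checked per dimension with a small pure-Python helper: the output size is `ceil(size / stride)`, the total padding is whatever lets the dilated kernel cover that many positions, and the split floors toward the top/left so any odd pixel goes to the bottom/right.

```python
# One-dimensional version of the 'same' padding computation used above.
def same_pad_1d(size, ksize, stride, rate):
    out = (size + stride - 1) // stride       # ceil(size / stride)
    effective_k = (ksize - 1) * rate + 1      # dilated kernel extent
    total = max(0, (out - 1) * stride + effective_k - size)
    before = total // 2                       # floor: extra pixel goes after
    after = total - before
    return before, after

# A 3x3 kernel at stride 1 needs one pixel of padding on each side.
top, bottom = same_pad_1d(10, 3, 1, 1)
```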
================================================
FILE: src/model/vdsr.py
================================================
from model import common
import torch.nn as nn
import torch.nn.init as init
url = {
'r20f64': ''
}
def make_model(args, parent=False):
return VDSR(args)
class VDSR(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(VDSR, self).__init__()
n_resblocks = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
self.url = url['r{}f{}'.format(n_resblocks, n_feats)]
self.sub_mean = common.MeanShift(args.rgb_range)
self.add_mean = common.MeanShift(args.rgb_range, sign=1)
def basic_block(in_channels, out_channels, act):
return common.BasicBlock(
conv, in_channels, out_channels, kernel_size,
bias=True, bn=False, act=act
)
# define body module
m_body = []
m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True)))
for _ in range(n_resblocks - 2):
m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True)))
m_body.append(basic_block(n_feats, args.n_colors, None))
self.body = nn.Sequential(*m_body)
def forward(self, x):
x = self.sub_mean(x)
res = self.body(x)
res += x
x = self.add_mean(res)
return x
================================================
FILE: src/option.py
================================================
import argparse
import template
parser = argparse.ArgumentParser(description='EDSR and MDSR')
parser.add_argument('--debug', action='store_true',
help='Enables debug mode')
parser.add_argument('--template', default='.',
help='You can set various templates in option.py')
# Hardware specifications
parser.add_argument('--n_threads', type=int, default=18,
help='number of threads for data loading')
parser.add_argument('--cpu', action='store_true',
help='use cpu only')
parser.add_argument('--n_GPUs', type=int, default=1,
help='number of GPUs')
parser.add_argument('--seed', type=int, default=1,
help='random seed')
parser.add_argument('--local_rank',type=int, default=0)
# Data specifications
parser.add_argument('--dir_data', type=str, default='../../../',
help='dataset directory')
parser.add_argument('--dir_demo', type=str, default='../Demo',
help='demo image directory')
parser.add_argument('--data_train', type=str, default='DIV2K',
help='train dataset name')
parser.add_argument('--data_test', type=str, default='DIV2K',
help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-800/801-810',
help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=192,
help='output patch size')
parser.add_argument('--rgb_range', type=int, default=255,
help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
help='number of color channels to use')
parser.add_argument('--chunk_size',type=int,default=144,
help='attention bucket size')
parser.add_argument('--n_hashes',type=int,default=4,
help='number of hash rounds')
parser.add_argument('--chop', action='store_true',
help='enable memory-efficient forward')
parser.add_argument('--no_augment', action='store_true',
help='do not use data augmentation')
# Model specifications
parser.add_argument('--model', default='EDSR',
help='model name')
parser.add_argument('--act', type=str, default='relu',
help='activation function')
parser.add_argument('--pre_train', type=str, default='.',
help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.',
help='pre-trained model directory')
parser.add_argument('--n_resblocks', type=int, default=20,
help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64,
help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1,
help='residual scaling')
parser.add_argument('--shift_mean', default=True,
help='subtract pixel mean from the input')
parser.add_argument('--dilation', action='store_true',
help='use dilated convolution')
parser.add_argument('--precision', type=str, default='single',
choices=('single', 'half'),
help='FP precision for test (single | half)')
# Option for Residual dense network (RDN)
parser.add_argument('--G0', type=int, default=64,
help='default number of filters. (Use in RDN)')
parser.add_argument('--RDNkSize', type=int, default=3,
help='default kernel size. (Use in RDN)')
parser.add_argument('--RDNconfig', type=str, default='B',
help='parameters config of RDN. (Use in RDN)')
parser.add_argument('--depth', type=int, default=12,
                    help='number of recurrent steps (used in MSSR)')
# Option for Residual channel attention network (RCAN)
parser.add_argument('--n_resgroups', type=int, default=10,
help='number of residual groups')
parser.add_argument('--reduction', type=int, default=16,
help='number of feature maps reduction')
# Training specifications
parser.add_argument('--reset', action='store_true',
help='reset the training')
parser.add_argument('--test_every', type=int, default=1000,
help='do test per every N batches')
parser.add_argument('--epochs', type=int, default=1000,
help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=16,
help='input batch size for training')
parser.add_argument('--split_batch', type=int, default=1,
help='split the batch into smaller chunks')
parser.add_argument('--self_ensemble', action='store_true',
help='use self-ensemble method for test')
parser.add_argument('--test_only', action='store_true',
help='set this option to test the model')
parser.add_argument('--gan_k', type=int, default=1,
help='k value for adversarial loss')
# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-4,
help='learning rate')
parser.add_argument('--decay', type=str, default='200',
help='learning rate decay type')
parser.add_argument('--gamma', type=float, default=0.5,
help='learning rate decay factor for step decay')
parser.add_argument('--optimizer', default='ADAM',
choices=('SGD', 'ADAM', 'RMSprop'),
help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9,
help='SGD momentum')
parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
help='ADAM beta')
parser.add_argument('--epsilon', type=float, default=1e-8,
help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
help='weight decay')
parser.add_argument('--gclip', type=float, default=0,
help='gradient clipping threshold (0 = no clipping)')
# Loss specifications
parser.add_argument('--loss', type=str, default='1*L1',
help='loss function configuration')
parser.add_argument('--skip_threshold', type=float, default=1e8,
help='skipping batch that has large error')
# Log specifications
parser.add_argument('--save', type=str, default='test',
help='file name to save')
parser.add_argument('--load', type=str, default='',
help='file name to load')
parser.add_argument('--resume', type=int, default=0,
help='resume from specific checkpoint')
parser.add_argument('--save_models', action='store_true',
help='save all intermediate models')
parser.add_argument('--print_every', type=int, default=100,
help='how many batches to wait before logging training status')
parser.add_argument('--save_results', action='store_true',
help='save output results')
parser.add_argument('--save_gt', action='store_true',
help='save low-resolution and high-resolution images together')
args = parser.parse_args()
template.set_template(args)
args.scale = list(map(lambda x: int(x), args.scale.split('+')))
args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')
if args.epochs == 0:
args.epochs = 1e8
for arg in vars(args):
if vars(args)[arg] == 'True':
vars(args)[arg] = True
elif vars(args)[arg] == 'False':
vars(args)[arg] = False
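The post-processing above turns '+'-joined CLI strings into lists; the snippet below replays that logic on hypothetical argument values so one run can cover several scales and datasets.

```python
# Same split-on-'+' logic option.py applies to args.scale and the dataset
# names; the literal strings here are example inputs, not defaults.
scale = list(map(int, '2+3+4'.split('+')))
data_test = 'Set5+Set14'.split('+')
```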
================================================
FILE: src/template.py
================================================
def set_template(args):
# Set the templates here
if args.template.find('jpeg') >= 0:
args.data_train = 'DIV2K_jpeg'
args.data_test = 'DIV2K_jpeg'
args.epochs = 200
args.decay = '100'
if args.template.find('EDSR_paper') >= 0:
args.model = 'EDSR'
args.n_resblocks = 32
args.n_feats = 256
args.res_scale = 0.1
if args.template.find('MDSR') >= 0:
args.model = 'MDSR'
args.patch_size = 48
args.epochs = 650
if args.template.find('DDBPN') >= 0:
args.model = 'DDBPN'
args.patch_size = 128
args.scale = '4'
args.data_test = 'Set5'
args.batch_size = 20
args.epochs = 1000
args.decay = '500'
args.gamma = 0.1
args.weight_decay = 1e-4
args.loss = '1*MSE'
if args.template.find('GAN') >= 0:
args.epochs = 200
args.lr = 5e-5
args.decay = '150'
if args.template.find('RCAN') >= 0:
args.model = 'RCAN'
args.n_resgroups = 10
args.n_resblocks = 20
args.n_feats = 64
args.chop = True
if args.template.find('VDSR') >= 0:
args.model = 'VDSR'
args.n_resblocks = 20
args.n_feats = 64
args.patch_size = 41
args.lr = 1e-1
================================================
FILE: src/trainer.py
================================================
import os
import math
from decimal import Decimal
import utility
import torch
import torch.nn.utils as utils
from tqdm import tqdm
class Trainer():
def __init__(self, args, loader, my_model, my_loss, ckp):
self.args = args
self.scale = args.scale
self.ckp = ckp
self.loader_train = loader.loader_train
self.loader_test = loader.loader_test
self.model = my_model
self.loss = my_loss
self.optimizer = utility.make_optimizer(args, self.model)
if self.args.load != '':
self.optimizer.load(ckp.dir, epoch=len(ckp.log))
self.error_last = 1e8
def train(self):
self.loss.step()
epoch = self.optimizer.get_last_epoch() + 1
lr = self.optimizer.get_lr()
self.ckp.write_log(
'[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
)
self.loss.start_log()
self.model.train()
timer_data, timer_model = utility.timer(), utility.timer()
# TEMP
self.loader_train.dataset.set_scale(0)
for batch, (lr, hr, _,) in enumerate(self.loader_train):
lr, hr = self.prepare(lr, hr)
timer_data.hold()
timer_model.tic()
self.optimizer.zero_grad()
sr = self.model(lr, 0)
loss = self.loss(sr, hr)
loss.backward()
if self.args.gclip > 0:
utils.clip_grad_value_(
self.model.parameters(),
self.args.gclip
)
self.optimizer.step()
timer_model.hold()
if (batch + 1) % self.args.print_every == 0:
self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
(batch + 1) * self.args.batch_size,
len(self.loader_train.dataset),
self.loss.display_loss(batch),
timer_model.release(),
timer_data.release()))
timer_data.tic()
self.loss.end_log(len(self.loader_train))
self.error_last = self.loss.log[-1, -1]
self.optimizer.schedule()
def test(self):
torch.set_grad_enabled(False)
epoch = self.optimizer.get_last_epoch()
self.ckp.write_log('\nEvaluation:')
self.ckp.add_log(
torch.zeros(1, len(self.loader_test), len(self.scale))
)
self.model.eval()
timer_test = utility.timer()
if self.args.save_results: self.ckp.begin_background()
for idx_data, d in enumerate(self.loader_test):
for idx_scale, scale in enumerate(self.scale):
d.dataset.set_scale(idx_scale)
for lr, hr, filename in tqdm(d, ncols=80):
lr, hr = self.prepare(lr, hr)
sr = self.model(lr, idx_scale)
sr = utility.quantize(sr, self.args.rgb_range)
save_list = [sr]
self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
sr, hr, scale, self.args.rgb_range, dataset=d
)
if self.args.save_gt:
save_list.extend([lr, hr])
if self.args.save_results:
self.ckp.save_results(d, filename[0], save_list, scale)
self.ckp.log[-1, idx_data, idx_scale] /= len(d)
best = self.ckp.log.max(0)
self.ckp.write_log(
'[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
d.dataset.name,
scale,
self.ckp.log[-1, idx_data, idx_scale],
best[0][idx_data, idx_scale],
best[1][idx_data, idx_scale] + 1
)
)
self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
self.ckp.write_log('Saving...')
if self.args.save_results:
self.ckp.end_background()
if not self.args.test_only:
self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))
self.ckp.write_log(
'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
)
torch.set_grad_enabled(True)
def prepare(self, *args):
device = torch.device('cpu' if self.args.cpu else 'cuda')
def _prepare(tensor):
if self.args.precision == 'half': tensor = tensor.half()
return tensor.to(device)
return [_prepare(a) for a in args]
def terminate(self):
if self.args.test_only:
self.test()
return True
else:
epoch = self.optimizer.get_last_epoch() + 1
return epoch >= self.args.epochs
================================================
FILE: src/utility.py
================================================
import os
import math
import time
import datetime
from multiprocessing import Process
from multiprocessing import Queue
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import imageio
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
class timer():
def __init__(self):
self.acc = 0
self.tic()
def tic(self):
self.t0 = time.time()
def toc(self, restart=False):
diff = time.time() - self.t0
if restart: self.t0 = time.time()
return diff
def hold(self):
self.acc += self.toc()
def release(self):
ret = self.acc
self.acc = 0
return ret
def reset(self):
self.acc = 0
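The timer above implements a lap accumulator: `tic()` marks a lap start, `hold()` folds the elapsed lap into `.acc`, and `release()` returns the total and clears it, which is how `trainer.py` keeps data-loading time separate from model time. A self-contained re-implementation of the same pattern (the class name is made up for the sketch):

```python
import time

# Minimal accumulate-timer mirroring utility.timer's tic/hold/release cycle.
class LapTimer:
    def __init__(self):
        self.acc = 0.0
        self.tic()
    def tic(self):
        self.t0 = time.time()          # start a new lap
    def hold(self):
        self.acc += time.time() - self.t0  # bank the current lap
    def release(self):
        ret, self.acc = self.acc, 0.0      # hand back the total, reset
        return ret

t = LapTimer()
t.hold()
total = t.release()
```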
class checkpoint():
def __init__(self, args):
self.args = args
self.ok = True
self.log = torch.Tensor()
now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
if not args.load:
if not args.save:
args.save = now
self.dir = os.path.join('..', 'experiment', args.save)
else:
self.dir = os.path.join('..', 'experiment', args.load)
if os.path.exists(self.dir):
self.log = torch.load(self.get_path('psnr_log.pt'))
print('Continue from epoch {}...'.format(len(self.log)))
else:
args.load = ''
if args.reset:
os.system('rm -rf ' + self.dir)
args.load = ''
os.makedirs(self.dir, exist_ok=True)
os.makedirs(self.get_path('model'), exist_ok=True)
for d in args.data_test:
os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)
        open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
self.log_file = open(self.get_path('log.txt'), open_type)
with open(self.get_path('config.txt'), open_type) as f:
f.write(now + '\n\n')
for arg in vars(args):
f.write('{}: {}\n'.format(arg, getattr(args, arg)))
f.write('\n')
self.n_processes = 8
def get_path(self, *subdir):
return os.path.join(self.dir, *subdir)
def save(self, trainer, epoch, is_best=False):
trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
trainer.loss.save(self.dir)
trainer.loss.plot_loss(self.dir, epoch)
self.plot_psnr(epoch)
trainer.optimizer.save(self.dir)
torch.save(self.log, self.get_path('psnr_log.pt'))
def add_log(self, log):
self.log = torch.cat([self.log, log])
def write_log(self, log, refresh=False):
print(log)
self.log_file.write(log + '\n')
if refresh:
self.log_file.close()
self.log_file = open(self.get_path('log.txt'), 'a')
def done(self):
self.log_file.close()
def plot_psnr(self, epoch):
axis = np.linspace(1, epoch, epoch)
for idx_data, d in enumerate(self.args.data_test):
label = 'SR on {}'.format(d)
fig = plt.figure()
plt.title(label)
for idx_scale, scale in enumerate(self.args.scale):
plt.plot(
axis,
self.log[:, idx_data, idx_scale].numpy(),
label='Scale {}'.format(scale)
)
plt.legend()
plt.xlabel('Epochs')
plt.ylabel('PSNR')
plt.grid(True)
plt.savefig(self.get_path('test_{}.pdf'.format(d)))
plt.close(fig)
def begin_background(self):
self.queue = Queue()
def bg_target(queue):
while True:
if not queue.empty():
filename, tensor = queue.get()
if filename is None: break
imageio.imwrite(filename, tensor.numpy())
self.process = [
Process(target=bg_target, args=(self.queue,)) \
for _ in range(self.n_processes)
]
for p in self.process: p.start()
def end_background(self):
for _ in range(self.n_processes): self.queue.put((None, None))
while not self.queue.empty(): time.sleep(1)
for p in self.process: p.join()
def save_results(self, dataset, filename, save_list, scale):
if self.args.save_results:
filename = self.get_path(
'results-{}'.format(dataset.dataset.name),
'{}_x{}_'.format(filename, scale)
)
postfix = ('SR', 'LR', 'HR')
for v, p in zip(save_list, postfix):
normalized = v[0].mul(255 / self.args.rgb_range)
tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))
def quantize(img, rgb_range):
pixel_range = 255 / rgb_range
return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
if hr.nelement() == 1: return 0
diff = (sr - hr) / rgb_range
if dataset and dataset.dataset.benchmark:
shave = scale
if diff.size(1) > 1:
    # Benchmark PSNR is computed on the Y (luminance) channel only
    gray_coeffs = [65.738, 129.057, 25.064]
    convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
    diff = diff.mul(convert).sum(dim=1)
else:
shave = scale + 6
valid = diff[..., shave:-shave, shave:-shave]
mse = valid.pow(2).mean()
return -10 * math.log10(mse)
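# Illustrative sanity check (not part of the original file): calc_psnr evaluates
# -10 * log10(mse) on the shaved, range-normalized residual, so a uniform
# squared error of 1e-4 corresponds to exactly 40 dB.
if __name__ == '__main__':
    import math
    assert abs(-10 * math.log10(1e-4) - 40.0) < 1e-9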
def make_optimizer(args, target):
'''
make optimizer and scheduler together
'''
# optimizer
trainable = filter(lambda x: x.requires_grad, target.parameters())
kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}
if args.optimizer == 'SGD':
optimizer_class = optim.SGD
kwargs_optimizer['momentum'] = args.momentum
elif args.optimizer == 'ADAM':
optimizer_class = optim.Adam
kwargs_optimizer['betas'] = args.betas
kwargs_optimizer['eps'] = args.epsilon
elif args.optimizer == 'RMSprop':
optimizer_class = optim.RMSprop
kwargs_optimizer['eps'] = args.epsilon
# scheduler
milestones = list(map(int, args.decay.split('-')))
kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
scheduler_class = lrs.MultiStepLR
class CustomOptimizer(optimizer_class):
def __init__(self, *args, **kwargs):
super(CustomOptimizer, self).__init__(*args, **kwargs)
def _register_scheduler(self, scheduler_class, **kwargs):
self.scheduler = scheduler_class(self, **kwargs)
def save(self, save_dir):
torch.save(self.state_dict(), self.get_dir(save_dir))
def load(self, load_dir, epoch=1):
    self.load_state_dict(torch.load(self.get_dir(load_dir)))
    # Fast-forward the scheduler so the learning rate matches the resumed epoch
    if epoch > 1:
        for _ in range(epoch): self.scheduler.step()
def get_dir(self, dir_path):
return os.path.join(dir_path, 'optimizer.pt')
def schedule(self):
self.scheduler.step()
def get_lr(self):
return self.scheduler.get_lr()[0]
def get_last_epoch(self):
return self.scheduler.last_epoch
optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
return optimizer
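# Illustrative note (not part of the original file): args.decay is a
# dash-separated milestone string, e.g. '200-400-600' (hypothetical value),
# which make_optimizer parses into [200, 400, 600] for MultiStepLR.
if __name__ == '__main__':
    assert list(map(int, '200-400-600'.split('-'))) == [200, 400, 600]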
================================================
FILE: src/utils/__init__.py
================================================
================================================
FILE: src/utils/tools.py
================================================
import os
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F
def normalize(x):
    # In-place map from [0, 1] to [-1, 1]
    return x.mul_(2).add_(-1)
def same_padding(images, ksizes, strides, rates):
assert len(images.size()) == 4
batch_size, channel, rows, cols = images.size()
out_rows = (rows + strides[0] - 1) // strides[0]
out_cols = (cols + strides[1] - 1) // strides[1]
effective_k_row = (ksizes[0] - 1) * rates[0] + 1
effective_k_col = (ksizes[1] - 1) * rates[1] + 1
padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
# Pad the input
padding_top = int(padding_rows / 2.)
padding_left = int(padding_cols / 2.)
padding_bottom = padding_rows - padding_top
padding_right = padding_cols - padding_left
paddings = (padding_left, padding_right, padding_top, padding_bottom)
images = torch.nn.ZeroPad2d(paddings)(images)
return images
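# Worked example (illustrative, not part of the original file): for a 7x7 input
# with a 3x3 kernel, stride 1 and dilation 1, out_rows = 7 and
# effective_k_row = 3, so padding_rows = (7 - 1) * 1 + 3 - 7 = 2,
# split into one row of padding on top and one on the bottom.
if __name__ == '__main__':
    _rows, _stride, _ksize, _rate = 7, 1, 3, 1
    _out_rows = (_rows + _stride - 1) // _stride
    _effective_k = (_ksize - 1) * _rate + 1
    _pad = max(0, (_out_rows - 1) * _stride + _effective_k - _rows)
    assert (_pad, _pad // 2, _pad - _pad // 2) == (2, 1, 1)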
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
"""
Extract patches from images and put them in the C output dimension.
:param padding:
:param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
:param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
each dimension of images
:param strides: [stride_rows, stride_cols]
:param rates: [dilation_rows, dilation_cols]
:return: A Tensor
"""
assert len(images.size()) == 4
assert padding in ['same', 'valid']
batch_size, channel, height, width = images.size()
if padding == 'same':
images = same_padding(images, ksizes, strides, rates)
elif padding == 'valid':
pass
else:
    raise NotImplementedError(
        'Unsupported padding type: {}. '
        'Only "same" or "valid" are supported.'.format(padding))
unfold = torch.nn.Unfold(kernel_size=ksizes,
dilation=rates,
padding=0,
stride=strides)
patches = unfold(images)
return patches # [N, C*k*k, L], L is the total number of such blocks
def reduce_mean(x, axis=None, keepdim=False):
    # Test for None explicitly: 'not axis' would also trigger on axis=[0]
    if axis is None:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.mean(x, dim=i, keepdim=keepdim)
    return x
def reduce_std(x, axis=None, keepdim=False):
    if axis is None:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.std(x, dim=i, keepdim=keepdim)
    return x
def reduce_sum(x, axis=None, keepdim=False):
    if axis is None:
        axis = range(len(x.shape))
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x
================================================
FILE: src/videotester.py
================================================
import os
import math
import utility
from data import common
import torch
import cv2
from tqdm import tqdm
class VideoTester():
def __init__(self, args, my_model, ckp):
self.args = args
self.scale = args.scale
self.ckp = ckp
self.model = my_model
self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
def test(self):
torch.set_grad_enabled(False)
self.ckp.write_log('\nEvaluation on video:')
self.model.eval()
timer_test = utility.timer()
for idx_scale, scale in enumerate(self.scale):
vidcap = cv2.VideoCapture(self.args.dir_demo)
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
vidwri = cv2.VideoWriter(
self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)),
cv2.VideoWriter_fourcc(*'XVID'),
vidcap.get(cv2.CAP_PROP_FPS),
(
int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
)
)
tqdm_test = tqdm(range(total_frames), ncols=80)
for _ in tqdm_test:
success, lr = vidcap.read()
if not success: break
lr, = common.set_channel(lr, n_channels=self.args.n_colors)
lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
lr, = self.prepare(lr.unsqueeze(0))
sr = self.model(lr, idx_scale)
sr = utility.quantize(sr, self.args.rgb_range).squeeze(0)
normalized = sr * 255 / self.args.rgb_range
ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
vidwri.write(ndarr)
vidcap.release()
vidwri.release()
self.ckp.write_log(
'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
)
torch.set_grad_enabled(True)
def prepare(self, *args):
device = torch.device('cpu' if self.args.cpu else 'cuda')
def _prepare(tensor):
if self.args.precision == 'half': tensor = tensor.half()
return tensor.to(device)
return [_prepare(a) for a in args]