Repository: HarukiYqM/Non-Local-Sparse-Attention
Branch: main
Commit: fb3d42255832
Files: 43
Total size: 117.5 KB
Directory structure:
gitextract_22vbrkxt/
├── README.md
└── src/
    ├── __init__.py
    ├── data/
    │   ├── __init__.py
    │   ├── benchmark.py
    │   ├── common.py
    │   ├── demo.py
    │   ├── div2k.py
    │   ├── div2kjpeg.py
    │   ├── sr291.py
    │   ├── srdata.py
    │   └── video.py
    ├── dataloader.py
    ├── demo.sh
    ├── loss/
    │   ├── __init__.py
    │   ├── __loss__.py
    │   ├── adversarial.py
    │   ├── demo.sh
    │   ├── discriminator.py
    │   ├── hash.py
    │   └── vgg.py
    ├── main.py
    ├── model/
    │   ├── LICENSE
    │   ├── README.md
    │   ├── __init__.py
    │   ├── attention.py
    │   ├── common.py
    │   ├── ddbpn.py
    │   ├── edsr.py
    │   ├── mdsr.py
    │   ├── mssr.py
    │   ├── nlsn.py
    │   ├── rcan.py
    │   ├── rdn.py
    │   ├── utils/
    │   │   ├── __init__.py
    │   │   └── tools.py
    │   └── vdsr.py
    ├── option.py
    ├── template.py
    ├── trainer.py
    ├── utility.py
    ├── utils/
    │   ├── __init__.py
    │   └── tools.py
    └── videotester.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# Image Super-Resolution with Non-Local Sparse Attention
This repository is for NLSN, introduced in the paper "Image Super-Resolution with Non-Local Sparse Attention", CVPR 2021, [[Link]](https://openaccess.thecvf.com/content/CVPR2021/papers/Mei_Image_Super-Resolution_With_Non-Local_Sparse_Attention_CVPR_2021_paper.pdf).
The code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) and tested on Ubuntu 18.04 (Python 3.6, PyTorch >= 1.1.0) with V100 GPUs.
## Contents
1. [Introduction](#introduction)
2. [Train](#train)
3. [Test](#test)
4. [Citation](#citation)
5. [Acknowledgements](#acknowledgements)
## Introduction
Both the Non-Local (NL) operation and sparse representation are crucial for Single Image Super-Resolution (SISR). In this paper, we investigate their combination and propose a novel Non-Local Sparse Attention (NLSA) with a dynamic sparse attention pattern. NLSA is designed to retain the long-range modeling capability of the NL operation while enjoying the robustness and high efficiency of sparse representation. Specifically, NLSA rectifies non-local attention with spherical locality sensitive hashing (LSH), which partitions the input space into hash buckets of related features. For every query signal, NLSA assigns it a bucket and only computes attention within that bucket. The resulting sparse attention prevents the model from attending to locations that are noisy and less informative, while reducing the computational cost from quadratic to asymptotically linear with respect to the spatial size. Extensive experiments validate the effectiveness and efficiency of NLSA. With a few non-local sparse attention modules, our architecture, called the non-local sparse network (NLSN), reaches state-of-the-art performance for SISR quantitatively and qualitatively.

Non-Local Sparse Attention.

Non-Local Sparse Network.
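The bucketing step described above can be sketched with a toy spherical LSH in NumPy. This is an illustrative simplification only, not the repository's implementation (which lives in `src/model/attention.py` and operates on PyTorch tensors); the function name `lsh_buckets` and all shapes here are hypothetical.

```python
import numpy as np

def lsh_buckets(x, n_hashes=4, n_buckets=8, seed=0):
    """Toy spherical LSH: project each feature with random rotations and
    use the strongest signed direction as its bucket id.

    x: [L, C] array of L feature vectors; returns [n_hashes, L] bucket ids.
    """
    rng = np.random.default_rng(seed)
    L, C = x.shape
    # One random rotation per hash round, n_buckets // 2 directions each
    rot = rng.standard_normal((n_hashes, C, n_buckets // 2))
    proj = np.einsum('lc,hcb->hlb', x, rot)
    # Append negated projections so +/- directions together give n_buckets codes
    codes = np.concatenate([proj, -proj], axis=-1)  # [n_hashes, L, n_buckets]
    return codes.argmax(axis=-1)

# Each query position would then attend only to positions sharing its bucket id
buckets = lsh_buckets(np.random.default_rng(1).standard_normal((64, 16)))
```

Attention is then computed only among positions that land in the same bucket, which is what reduces the cost from quadratic to roughly linear in the number of spatial positions.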
## Train
### Prepare training data
1. Download the DIV2K training data (800 training + 100 validation images) from the [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar).
2. Specify `--dir_data` based on the HR and LR image paths.
For more information, please refer to [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch).
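As a quick sanity check, the data loaders in this repository resolve the DIV2K folders relative to `--dir_data` as shown below (mirroring `_set_filesystem` and `_scan` in `src/data/div2k.py`). The helper name `div2k_paths` is ours, for illustration only:

```python
import os

def div2k_paths(dir_data, scale=2):
    # Mirrors src/data/div2k.py: HR images and per-scale bicubic LR
    # images are expected under <dir_data>/DIV2K
    apath = os.path.join(dir_data, 'DIV2K')
    dir_hr = os.path.join(apath, 'DIV2K_train_HR')
    dir_lr = os.path.join(apath, 'DIV2K_train_LR_bicubic', 'X{}'.format(scale))
    return dir_hr, dir_lr

hr_dir, lr_dir = div2k_paths('../../')
```

With `--dir_data ../../` as in `demo.sh`, the loader therefore expects `../../DIV2K/DIV2K_train_HR` and `../../DIV2K/DIV2K_train_LR_bicubic/X2`.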
### Begin to train
1. (optional) Download the pretrained models for our paper.
Pre-trained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zz2a1ih3euzuH3HvWDN-uSki3USym9Cq?usp=sharing).
2. Cd to `src` and run the following script to train models.
**An example command is in the file `demo.sh`.**
```bash
# Example X2 SR
python main.py --dir_data ../../ --n_GPUs 4 --rgb_range 1 --chunk_size 144 --n_hashes 4 --save_models --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model NLSN --scale 2 --patch_size 96 --save NLSN_x2 --data_train DIV2K
```
## Test
### Quick start
1. Download the benchmark datasets from [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/benchmark.tar).
2. (optional) Download the pretrained models for our paper.
All the models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zz2a1ih3euzuH3HvWDN-uSki3USym9Cq?usp=sharing).
3. Cd to `src` and run the following scripts.
**An example command is in the file `demo.sh`.**
```bash
# No self-ensemble: NLSN
# Example X2 SR
python main.py --dir_data ../../ --model NLSN --chunk_size 144 --data_test Set5+Set14+B100+Urban100 --n_hashes 4 --chop --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1 --pre_train model_x2.pt --test_only
```
## Citation
If you find the code helpful in your research or work, please cite the following papers.
```
@InProceedings{Mei_2021_CVPR,
    author = {Mei, Yiqun and Fan, Yuchen and Zhou, Yuqian},
    title = {Image Super-Resolution With Non-Local Sparse Attention},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month = {June},
    year = {2021},
    pages = {3517-3526}
}
@InProceedings{Lim_2017_CVPR_Workshops,
    author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
    title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
    booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
    month = {July},
    year = {2017}
}
```
## Acknowledgements
This code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) and [reformer-pytorch](https://github.com/lucidrains/reformer-pytorch). We thank the authors for sharing their codes.
================================================
FILE: src/__init__.py
================================================
================================================
FILE: src/data/__init__.py
================================================
from importlib import import_module
#from dataloader import MSDataLoader
from torch.utils.data import dataloader
from torch.utils.data import ConcatDataset
# This is a simple wrapper class for ConcatDataset
class MyConcatDataset(ConcatDataset):
    def __init__(self, datasets):
        super(MyConcatDataset, self).__init__(datasets)
        self.train = datasets[0].train

    def set_scale(self, idx_scale):
        for d in self.datasets:
            if hasattr(d, 'set_scale'): d.set_scale(idx_scale)

class Data:
    def __init__(self, args):
        self.loader_train = None
        if not args.test_only:
            datasets = []
            for d in args.data_train:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data.' + module_name.lower())
                datasets.append(getattr(m, module_name)(args, name=d))

            self.loader_train = dataloader.DataLoader(
                MyConcatDataset(datasets),
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=not args.cpu,
                num_workers=args.n_threads,
            )

        self.loader_test = []
        for d in args.data_test:
            if d in ['Set5', 'Set14', 'B100', 'Urban100']:
                m = import_module('data.benchmark')
                testset = getattr(m, 'Benchmark')(args, train=False, name=d)
            else:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data.' + module_name.lower())
                testset = getattr(m, module_name)(args, train=False, name=d)

            self.loader_test.append(
                dataloader.DataLoader(
                    testset,
                    batch_size=1,
                    shuffle=False,
                    pin_memory=not args.cpu,
                    num_workers=args.n_threads,
                )
            )
================================================
FILE: src/data/benchmark.py
================================================
import os
from data import common
from data import srdata
import numpy as np
import torch
import torch.utils.data as data
class Benchmark(srdata.SRData):
    def __init__(self, args, name='', train=True, benchmark=True):
        super(Benchmark, self).__init__(
            args, name=name, train=train, benchmark=True
        )

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, 'benchmark', self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        if self.input_large:
            self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')
        else:
            self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('', '.png')
================================================
FILE: src/data/common.py
================================================
import random
import numpy as np
import skimage.color as sc
import torch
def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
    ih, iw = args[0].shape[:2]

    if not input_large:
        p = scale if multi else 1
        tp = p * patch_size
        ip = tp // scale
    else:
        tp = patch_size
        ip = patch_size

    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)

    if not input_large:
        tx, ty = scale * ix, scale * iy
    else:
        tx, ty = ix, iy

    ret = [
        args[0][iy:iy + ip, ix:ix + ip, :],
        *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
    ]

    return ret

def set_channel(*args, n_channels=3):
    def _set_channel(img):
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)

        c = img.shape[2]
        if n_channels == 1 and c == 3:
            img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
        elif n_channels == 3 and c == 1:
            img = np.concatenate([img] * n_channels, 2)

        return img

    return [_set_channel(a) for a in args]

def np2Tensor(*args, rgb_range=255):
    def _np2Tensor(img):
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        tensor = torch.from_numpy(np_transpose).float()
        tensor.mul_(rgb_range / 255)

        return tensor

    return [_np2Tensor(a) for a in args]

def augment(*args, hflip=True, rot=True):
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip: img = img[:, ::-1, :]
        if vflip: img = img[::-1, :, :]
        if rot90: img = img.transpose(1, 0, 2)

        return img

    return [_augment(a) for a in args]
================================================
FILE: src/data/demo.py
================================================
import os
from data import common
import numpy as np
import imageio
import torch
import torch.utils.data as data
class Demo(data.Dataset):
    def __init__(self, args, name='Demo', train=False, benchmark=False):
        self.args = args
        self.name = name
        self.scale = args.scale
        self.idx_scale = 0
        self.train = False
        self.benchmark = benchmark

        self.filelist = []
        for f in os.listdir(args.dir_demo):
            if f.find('.png') >= 0 or f.find('.jp') >= 0:
                self.filelist.append(os.path.join(args.dir_demo, f))
        self.filelist.sort()

    def __getitem__(self, idx):
        filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
        lr = imageio.imread(self.filelist[idx])
        lr, = common.set_channel(lr, n_channels=self.args.n_colors)
        lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)

        return lr_t, -1, filename

    def __len__(self):
        return len(self.filelist)

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale
================================================
FILE: src/data/div2k.py
================================================
import os
from data import srdata
class DIV2K(srdata.SRData):
    def __init__(self, args, name='DIV2K', train=True, benchmark=False):
        data_range = [r.split('-') for r in args.data_range.split('/')]
        if train:
            data_range = data_range[0]
        else:
            if args.test_only and len(data_range) == 1:
                data_range = data_range[0]
            else:
                data_range = data_range[1]

        self.begin, self.end = list(map(lambda x: int(x), data_range))
        super(DIV2K, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )

    def _scan(self):
        names_hr, names_lr = super(DIV2K, self)._scan()
        names_hr = names_hr[self.begin - 1:self.end]
        names_lr = [n[self.begin - 1:self.end] for n in names_lr]

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        super(DIV2K, self)._set_filesystem(dir_data)
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')
        if self.input_large: self.dir_lr += 'L'
================================================
FILE: src/data/div2kjpeg.py
================================================
import os
from data import srdata
from data import div2k
class DIV2KJPEG(div2k.DIV2K):
    def __init__(self, args, name='', train=True, benchmark=False):
        self.q_factor = int(name.replace('DIV2K-Q', ''))
        super(DIV2KJPEG, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, 'DIV2K')
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(
            self.apath, 'DIV2K_Q{}'.format(self.q_factor)
        )
        if self.input_large: self.dir_lr += 'L'
        self.ext = ('.png', '.jpg')
================================================
FILE: src/data/sr291.py
================================================
from data import srdata
class SR291(srdata.SRData):
    def __init__(self, args, name='SR291', train=True, benchmark=False):
        super(SR291, self).__init__(args, name=name)
================================================
FILE: src/data/srdata.py
================================================
import os
import glob
import random
import pickle
from data import common
import numpy as np
import imageio
import torch
import torch.utils.data as data
class SRData(data.Dataset):
    def __init__(self, args, name='', train=True, benchmark=False):
        self.args = args
        self.name = name
        self.train = train
        self.split = 'train' if train else 'test'
        self.do_eval = True
        self.benchmark = benchmark
        self.input_large = (args.model == 'VDSR')
        self.scale = args.scale
        self.idx_scale = 0

        self._set_filesystem(args.dir_data)
        if args.ext.find('img') < 0:
            path_bin = os.path.join(self.apath, 'bin')
            os.makedirs(path_bin, exist_ok=True)

        list_hr, list_lr = self._scan()
        if args.ext.find('img') >= 0 or benchmark:
            self.images_hr, self.images_lr = list_hr, list_lr
        elif args.ext.find('sep') >= 0:
            os.makedirs(
                self.dir_hr.replace(self.apath, path_bin),
                exist_ok=True
            )
            for s in self.scale:
                os.makedirs(
                    os.path.join(
                        self.dir_lr.replace(self.apath, path_bin),
                        'X{}'.format(s)
                    ),
                    exist_ok=True
                )

            self.images_hr, self.images_lr = [], [[] for _ in self.scale]
            for h in list_hr:
                b = h.replace(self.apath, path_bin)
                b = b.replace(self.ext[0], '.pt')
                self.images_hr.append(b)
                self._check_and_load(args.ext, h, b, verbose=True)
            for i, ll in enumerate(list_lr):
                for l in ll:
                    b = l.replace(self.apath, path_bin)
                    b = b.replace(self.ext[1], '.pt')
                    self.images_lr[i].append(b)
                    self._check_and_load(args.ext, l, b, verbose=True)

        if train:
            n_patches = args.batch_size * args.test_every
            n_images = len(args.data_train) * len(self.images_hr)
            if n_images == 0:
                self.repeat = 0
            else:
                self.repeat = max(n_patches // n_images, 1)

    # The functions below are used to prepare images
    def _scan(self):
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
        )
        names_lr = [[] for _ in self.scale]
        for f in names_hr:
            filename, _ = os.path.splitext(os.path.basename(f))
            for si, s in enumerate(self.scale):
                names_lr[si].append(os.path.join(
                    self.dir_lr, 'X{}/{}x{}{}'.format(
                        s, filename, s, self.ext[1]
                    )
                ))

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        if self.input_large: self.dir_lr += 'L'
        self.ext = ('.png', '.png')

    def _check_and_load(self, ext, img, f, verbose=True):
        if not os.path.isfile(f) or ext.find('reset') >= 0:
            if verbose:
                print('Making a binary: {}'.format(f))
            with open(f, 'wb') as _f:
                pickle.dump(imageio.imread(img), _f)

    def __getitem__(self, idx):
        lr, hr, filename = self._load_file(idx)
        pair = self.get_patch(lr, hr)
        pair = common.set_channel(*pair, n_channels=self.args.n_colors)
        pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)

        return pair_t[0], pair_t[1], filename

    def __len__(self):
        if self.train:
            return len(self.images_hr) * self.repeat
        else:
            return len(self.images_hr)

    def _get_index(self, idx):
        if self.train:
            return idx % len(self.images_hr)
        else:
            return idx

    def _load_file(self, idx):
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]
        f_lr = self.images_lr[self.idx_scale][idx]

        filename, _ = os.path.splitext(os.path.basename(f_hr))
        if self.args.ext == 'img' or self.benchmark:
            hr = imageio.imread(f_hr)
            lr = imageio.imread(f_lr)
        elif self.args.ext.find('sep') >= 0:
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)
            with open(f_lr, 'rb') as _f:
                lr = pickle.load(_f)

        return lr, hr, filename

    def get_patch(self, lr, hr):
        scale = self.scale[self.idx_scale]
        if self.train:
            lr, hr = common.get_patch(
                lr, hr,
                patch_size=self.args.patch_size,
                scale=scale,
                multi=(len(self.scale) > 1),
                input_large=self.input_large
            )
            if not self.args.no_augment: lr, hr = common.augment(lr, hr)
        else:
            ih, iw = lr.shape[:2]
            hr = hr[0:ih * scale, 0:iw * scale]

        return lr, hr

    def set_scale(self, idx_scale):
        if not self.input_large:
            self.idx_scale = idx_scale
        else:
            self.idx_scale = random.randint(0, len(self.scale) - 1)
================================================
FILE: src/data/video.py
================================================
import os
from data import common
import cv2
import numpy as np
import imageio
import torch
import torch.utils.data as data
class Video(data.Dataset):
    def __init__(self, args, name='Video', train=False, benchmark=False):
        self.args = args
        self.name = name
        self.scale = args.scale
        self.idx_scale = 0
        self.train = False
        self.do_eval = False
        self.benchmark = benchmark

        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
        self.vidcap = cv2.VideoCapture(args.dir_demo)
        self.n_frames = 0
        self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __getitem__(self, idx):
        success, lr = self.vidcap.read()
        if success:
            self.n_frames += 1
            lr, = common.set_channel(lr, n_channels=self.args.n_colors)
            lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)

            return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
        else:
            self.vidcap.release()
            return None

    def __len__(self):
        return self.total_frames

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale
================================================
FILE: src/dataloader.py
================================================
import sys
import threading
import random

import torch
import torch.multiprocessing as multiprocessing

from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler
from torch.utils.data import RandomSampler
from torch.utils.data import BatchSampler
from torch.utils.data import _utils
from torch.utils.data.dataloader import _DataLoaderIter

from torch.utils.data._utils import collate
from torch.utils.data._utils import signal_handling
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data._utils import ExceptionWrapper
from torch.utils.data._utils import IS_WINDOWS
from torch.utils.data._utils.worker import ManagerWatchdog

from torch._six import queue

def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
    try:
        collate._use_shared_memory = True
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue

            if r is None:
                assert done_event.is_set()
                return
            elif done_event.is_set():
                continue

            idx, batch_indices = r
            try:
                idx_scale = 0
                if len(scale) > 1 and dataset.train:
                    idx_scale = random.randrange(0, len(scale))
                    dataset.set_scale(idx_scale)

                samples = collate_fn([dataset[i] for i in batch_indices])
                samples.append(idx_scale)
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples

    except KeyboardInterrupt:
        pass

class _MSDataLoaderIter(_DataLoaderIter):
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        self.sample_iter = iter(self.batch_sampler)

        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()

            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        index_queue,
                        self.worker_result_queue,
                        self.done_event,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                w.daemon = True
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(
                        self.worker_result_queue,
                        self.data_queue,
                        torch.cuda.current_device(),
                        self.done_event
                    )
                )
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                self.pin_memory_thread = pin_memory_thread
            else:
                self.data_queue = self.worker_result_queue

            _utils.signal_handling._set_worker_pids(
                id(self), tuple(w.pid for w in self.workers)
            )
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            for _ in range(2 * self.num_workers):
                self._put_indices()

class MSDataLoader(DataLoader):
    def __init__(self, cfg, *args, **kwargs):
        super(MSDataLoader, self).__init__(
            *args, **kwargs, num_workers=cfg.n_threads
        )
        self.scale = cfg.scale

    def __iter__(self):
        return _MSDataLoaderIter(self)
================================================
FILE: src/demo.sh
================================================
#!/bin/bash
#Train x2
python main.py --dir_data ../../ --n_GPUs 4 --rgb_range 1 --chunk_size 144 --n_hashes 4 --save_models --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model NLSN --scale 2 --patch_size 96 --save NLSN_x2 --data_train DIV2K
#Test x2
python main.py --dir_data ../../ --model NLSN --chunk_size 144 --data_test Set5+Set14+B100+Urban100 --n_hashes 4 --chop --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1 --pre_train model_x2.pt --test_only
================================================
FILE: src/loss/__init__.py
================================================
import os
from importlib import import_module
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class Loss(nn.modules.loss._Loss):
    def __init__(self, args, ckp):
        super(Loss, self).__init__()
        print('Preparing loss function:')

        self.n_GPUs = args.n_GPUs
        self.loss = []
        self.loss_module = nn.ModuleList()
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':
                loss_function = nn.MSELoss()
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif loss_type.find('VGG') >= 0:
                module = import_module('loss.vgg')
                loss_function = getattr(module, 'VGG')(
                    loss_type[3:],
                    rgb_range=args.rgb_range
                )
            elif loss_type.find('GAN') >= 0:
                module = import_module('loss.adversarial')
                loss_function = getattr(module, 'Adversarial')(
                    args,
                    loss_type
                )

            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function
            })
            if loss_type.find('GAN') >= 0:
                self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})

        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for l in self.loss:
            if l['function'] is not None:
                print('{:.3f} * {}'.format(l['weight'], l['type']))
                self.loss_module.append(l['function'])

        self.log = torch.Tensor()

        device = torch.device('cpu' if args.cpu else 'cuda')
        self.loss_module.to(device)
        if args.precision == 'half': self.loss_module.half()
        if not args.cpu and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(
                self.loss_module, range(args.n_GPUs)
            )

        if args.load != '': self.load(ckp.dir, cpu=args.cpu)

    def forward(self, sr, hr):
        losses = []
        for i, l in enumerate(self.loss):
            if l['function'] is not None:
                loss = l['function'](sr, hr)
                effective_loss = l['weight'] * loss
                losses.append(effective_loss)
                self.log[-1, i] += effective_loss.item()
            elif l['type'] == 'DIS':
                self.log[-1, i] += self.loss[i - 1]['function'].loss

        loss_sum = sum(losses)
        if len(self.loss) > 1:
            self.log[-1, -1] += loss_sum.item()

        return loss_sum

    def step(self):
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                l.scheduler.step()

    def start_log(self):
        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))

    def end_log(self, n_batches):
        self.log[-1].div_(n_batches)

    def display_loss(self, batch):
        n_samples = batch + 1
        log = []
        for l, c in zip(self.loss, self.log[-1]):
            log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))

        return ''.join(log)

    def plot_loss(self, apath, epoch):
        axis = np.linspace(1, epoch, epoch)
        for i, l in enumerate(self.loss):
            label = '{} Loss'.format(l['type'])
            fig = plt.figure()
            plt.title(label)
            plt.plot(axis, self.log[:, i].numpy(), label=label)
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
            plt.close(fig)

    def get_loss_module(self):
        if self.n_GPUs == 1:
            return self.loss_module
        else:
            return self.loss_module.module

    def save(self, apath):
        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))

    def load(self, apath, cpu=False):
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        self.load_state_dict(torch.load(
            os.path.join(apath, 'loss.pt'),
            **kwargs
        ))
        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                for _ in range(len(self.log)): l.scheduler.step()
================================================
FILE: src/loss/__loss__.py
================================================
================================================
FILE: src/loss/adversarial.py
================================================
import utility
from types import SimpleNamespace
from model import common
from loss import discriminator
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Adversarial(nn.Module):
    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.dis = discriminator.Discriminator(args)
        if gan_type == 'WGAN_GP':
            # see https://arxiv.org/pdf/1704.00028.pdf pp.4
            optim_dict = {
                'optimizer': 'ADAM',
                'betas': (0, 0.9),
                'epsilon': 1e-8,
                'lr': 1e-5,
                'weight_decay': args.weight_decay,
                'decay': args.decay,
                'gamma': args.gamma
            }
            optim_args = SimpleNamespace(**optim_dict)
        else:
            optim_args = args

        self.optimizer = utility.make_optimizer(optim_args, self.dis)

    def forward(self, fake, real):
        # updating discriminator...
        self.loss = 0
        fake_detach = fake.detach()     # do not backpropagate through G
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            # d: B x 1 tensor
            d_fake = self.dis(fake_detach)
            d_real = self.dis(real)
            retain_graph = False
            if self.gan_type == 'GAN':
                loss_d = self.bce(d_real, d_fake)
            elif self.gan_type.find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    # one interpolation weight per sample in the batch
                    epsilon = torch.rand(
                        fake.size(0), 1, 1, 1,
                        device=fake.device, dtype=fake.dtype
                    )
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.dis(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
            elif self.gan_type == 'RGAN':
                better_real = d_real - d_fake.mean(dim=0, keepdim=True)
                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
                loss_d = self.bce(better_real, better_fake)
                retain_graph = True

            # Discriminator update
            self.loss += loss_d.item()
            loss_d.backward(retain_graph=retain_graph)
            self.optimizer.step()

            if self.gan_type == 'WGAN':
                for p in self.dis.parameters():
                    p.data.clamp_(-1, 1)

        self.loss /= self.gan_k

        # updating generator...
        d_fake_bp = self.dis(fake)      # for backpropagation, use fake as it is
        if self.gan_type == 'GAN':
            label_real = torch.ones_like(d_fake_bp)
            loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_bp.mean()
        elif self.gan_type == 'RGAN':
            better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
            loss_g = self.bce(better_fake, better_real)

        # Generator loss
        return loss_g

    def state_dict(self, *args, **kwargs):
        state_discriminator = self.dis.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()

        return dict(**state_discriminator, **state_optimizer)

    def bce(self, real, fake):
        label_real = torch.ones_like(real)
        label_fake = torch.zeros_like(fake)
        bce_real = F.binary_cross_entropy_with_logits(real, label_real)
        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
        bce_loss = bce_real + bce_fake

        return bce_loss
# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
================================================
FILE: src/loss/demo.sh
================================================
================================================
FILE: src/loss/discriminator.py
================================================
from model import common
import torch.nn as nn
class Discriminator(nn.Module):
    '''
    output is not normalized
    '''
    def __init__(self, args):
        super(Discriminator, self).__init__()

        in_channels = args.n_colors
        out_channels = 64
        depth = 7

        def _block(_in_channels, _out_channels, stride=1):
            return nn.Sequential(
                nn.Conv2d(
                    _in_channels,
                    _out_channels,
                    3,
                    padding=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(_out_channels),
                nn.LeakyReLU(negative_slope=0.2, inplace=True)
            )

        m_features = [_block(in_channels, out_channels)]
        for i in range(depth):
            in_channels = out_channels
            if i % 2 == 1:
                stride = 1
                out_channels *= 2
            else:
                stride = 2
            m_features.append(_block(in_channels, out_channels, stride=stride))

        patch_size = args.patch_size // (2**((depth + 1) // 2))
        m_classifier = [
            nn.Linear(out_channels * patch_size**2, 1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(1024, 1)
        ]

        self.features = nn.Sequential(*m_features)
        self.classifier = nn.Sequential(*m_classifier)

    def forward(self, x):
        features = self.features(x)
        output = self.classifier(features.view(features.size(0), -1))

        return output
================================================
FILE: src/loss/hash.py
================================================
from model import common
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
class HASH(nn.Module):
    def __init__(self):
        super(HASH, self).__init__()
        self.l1 = nn.L1Loss()

    def forward(self, sr, qk, orders, hr, m=3):
        # hash loss
        qk = F.normalize(qk, p=2, dim=1, eps=5e-5)
        N, C, H, W = qk.shape
        qk = qk.view(N, C, H * W)
        qk_t = qk.permute(0, 2, 1).contiguous()
        similarity_map = F.relu(torch.matmul(qk_t, qk), inplace=True)  # [N, H*W, H*W]
        orders = orders.unsqueeze(2).expand_as(similarity_map)  # [N, H*W, H*W]
        orders_t = torch.transpose(orders, 1, 2)
        dist = torch.pow(orders - orders_t, 2)
        ls = torch.mean(similarity_map * torch.log(torch.exp(dist + m) + 1))
        ld = torch.mean((1 - similarity_map) * torch.log(torch.exp(-dist + m) + 1))
        loss = 0.005 * (ls + ld) + self.l1(sr, hr)
        return loss
================================================
FILE: src/loss/vgg.py
================================================
from model import common

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class VGG(nn.Module):
    def __init__(self, conv_index, rgb_range=1):
        super(VGG, self).__init__()
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m for m in vgg_features]
        if conv_index.find('22') >= 0:
            self.vgg = nn.Sequential(*modules[:8])   # features up to conv2_2
        elif conv_index.find('54') >= 0:
            self.vgg = nn.Sequential(*modules[:35])  # features up to conv5_4
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        def _forward(x):
            x = self.sub_mean(x)
            x = self.vgg(x)
            return x

        vgg_sr = _forward(sr)
        with torch.no_grad():
            vgg_hr = _forward(hr.detach())
        loss = F.mse_loss(vgg_sr, vgg_hr)
        return loss
================================================
FILE: src/main.py
================================================
import torch

import utility
import data
import model
import loss
from option import args
from trainer import Trainer

torch.manual_seed(args.seed)
checkpoint = utility.checkpoint(args)

def main():
    global model
    if args.data_test == ['video']:
        from videotester import VideoTester
        model = model.Model(args, checkpoint)
        print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))
        t = VideoTester(args, model, checkpoint)
        t.test()
    else:
        if checkpoint.ok:
            loader = data.Data(args)
            _model = model.Model(args, checkpoint)
            print('Total params: %.2fM' % (sum(p.numel() for p in _model.parameters()) / 1000000.0))
            _loss = loss.Loss(args, checkpoint) if not args.test_only else None
            t = Trainer(args, loader, _model, _loss, checkpoint)
            while not t.terminate():
                t.train()
                t.test()
            checkpoint.done()

if __name__ == '__main__':
    main()
================================================
FILE: src/model/LICENSE
================================================
MIT License
Copyright (c) 2018 Sanghyun Son
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: src/model/README.md
================================================
# EDSR-PyTorch

This repository is an official PyTorch implementation of the paper **"Enhanced Deep Residual Networks for Single Image Super-Resolution"** from **CVPRW 2017 (2nd NTIRE workshop)**.
You can find the original code and more information from [here](https://github.com/LimBee/NTIRE2017).
If you find our work useful in your research or publication, please cite our work:
[1] Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee, **"Enhanced Deep Residual Networks for Single Image Super-Resolution,"** <i>2nd NTIRE: New Trends in Image Restoration and Enhancement workshop and challenge on image super-resolution in conjunction with **CVPR 2017**. </i> [[PDF](http://openaccess.thecvf.com/content_cvpr_2017_workshops/w12/papers/Lim_Enhanced_Deep_Residual_CVPR_2017_paper.pdf)] [[arXiv](https://arxiv.org/abs/1707.02921)] [[Slide](https://cv.snu.ac.kr/research/EDSR/Presentation_v3(release).pptx)]
```
@InProceedings{Lim_2017_CVPR_Workshops,
author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},
title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
month = {July},
year = {2017}
}
```
We provide scripts for reproducing all the results from our paper. You can train your own model from scratch, or use a pre-trained model to enlarge your images.
**Differences from the Torch version**
* The code is much more compact. (All unnecessary parts were removed.)
* Models are smaller. (About half the size.)
* Slightly better performance.
* Training and evaluation require less memory.
* Python-based.
## Dependencies
* Python 3.6
* PyTorch >= 0.4.0
* numpy
* skimage
* **imageio**
* matplotlib
* tqdm
**Recent updates**
* July 22, 2018
* Thanks for the recent commits that contain RDN and RCAN. Please see ``code/demo.sh`` to train/test those models.
* The dataloader is now much more stable than the previous version. Please erase the ``DIV2K/bin`` folder created before this commit. Also, please avoid using the ``--ext bin`` argument. Our code will automatically pre-decode png images before training. If you do not have enough space (~10GB) on your disk, we recommend ``--ext img`` (but SLOW!).
## Code
Clone this repository into any place you want.
```bash
git clone https://github.com/thstkdgus35/EDSR-PyTorch
cd EDSR-PyTorch
```
## Quick start (Demo)
You can test our super-resolution algorithm with your own images. Place your images in the ``test`` folder. (like ``test/<your_image>``) We support **png** and **jpeg** files.
Run the script in the ``src`` folder. Before you run the demo, please uncomment the line in ```demo.sh``` that you want to execute.
```bash
cd src # You are now in */EDSR-PyTorch/src
sh demo.sh
```
You can find the result images in the ```experiment/test/results``` folder.
| Model | Scale | File name (.pt) | Parameters | **PSNR |
| --- | --- | --- | --- | --- |
| **EDSR** | 2 | EDSR_baseline_x2 | 1.37 M | 34.61 dB |
| | | *EDSR_x2 | 40.7 M | 35.03 dB |
| | 3 | EDSR_baseline_x3 | 1.55 M | 30.92 dB |
| | | *EDSR_x3 | 43.7 M | 31.26 dB |
| | 4 | EDSR_baseline_x4 | 1.52 M | 28.95 dB |
| | | *EDSR_x4 | 43.1 M | 29.25 dB |
| **MDSR** | 2 | MDSR_baseline | 3.23 M | 34.63 dB |
| | | *MDSR | 7.95 M| 34.92 dB |
| | 3 | MDSR_baseline | | 30.94 dB |
| | | *MDSR | | 31.22 dB |
| | 4 | MDSR_baseline | | 28.97 dB |
| | | *MDSR | | 29.24 dB |
*Baseline models are in ``experiment/model``. Please download our final models from [here](https://cv.snu.ac.kr/research/EDSR/model_pytorch.tar) (542MB)
**We measured PSNR using DIV2K 0801 ~ 0900, RGB channels, without self-ensemble. (scale + 2) pixels from the image boundary are ignored.
You can evaluate your models with widely-used benchmark datasets:
[Set5 - Bevilacqua et al. BMVC 2012](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html),
[Set14 - Zeyde et al. LNCS 2010](https://sites.google.com/site/romanzeyde/research-interests),
[B100 - Martin et al. ICCV 2001](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/),
[Urban100 - Huang et al. CVPR 2015](https://sites.google.com/site/jbhuang0604/publications/struct_sr).
For these datasets, we first convert the result images to YCbCr color space and evaluate PSNR on the Y channel only. You can download [benchmark datasets](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) (250MB). Set ``--dir_data <where_benchmark_folder_located>`` to evaluate the EDSR and MDSR with the benchmarks.
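To make this evaluation protocol concrete, here is a hedged NumPy sketch of Y-channel PSNR with border shaving; the ITU-R BT.601 conversion coefficients are the common convention in SR evaluation code, not something defined by this repository:

```python
import numpy as np

def psnr_y(sr, hr, shave=4):
    """Sketch of benchmark-style PSNR: RGB -> Y channel, shaved borders.

    sr, hr: float arrays in [0, 1], shape [H, W, 3] (RGB).
    shave:  border pixels ignored (often set to the scale factor).
    """
    coeffs = np.array([65.481, 128.553, 24.966])  # BT.601 RGB -> Y (assumed convention)
    y_sr = sr @ coeffs + 16.0                     # luma in the [16, 235] range
    y_hr = hr @ coeffs + 16.0
    diff = (y_sr - y_hr)[shave:-shave, shave:-shave]
    mse = np.mean(diff ** 2)
    return 10 * np.log10(255.0 ** 2 / mse)
```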
## How to train EDSR and MDSR
We used [DIV2K](http://www.vision.ee.ethz.ch/%7Etimofter/publications/Agustsson-CVPRW-2017.pdf) dataset to train our model. Please download it from [here](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (7.1GB).
Unpack the tar file to any place you want. Then, change the ```dir_data``` argument in ```src/option.py``` to the place where DIV2K images are located.
We recommend pre-processing the images before training. This step will decode all **png** files and save them as binaries. Use the ``--ext sep_reset`` argument on your first run. You can skip the decoding part and use saved binaries with the ``--ext sep`` argument.
If you have enough RAM (>= 32GB), you can use ``--ext bin`` argument to pack all DIV2K images in one binary file.
You can train EDSR and MDSR by yourself. All scripts are provided in ``src/demo.sh``. Note that EDSR (x3, x4) requires a pre-trained EDSR (x2). You can ignore this constraint by removing the ```--pre_train <x2 model>``` argument.
```bash
cd src # You are now in */EDSR-PyTorch/src
sh demo.sh
```
**Update log**
* Jan 04, 2018
* Many parts are re-written. You cannot use previous scripts and models directly.
* Pre-trained MDSR is temporarily disabled.
* Training details are included.
* Jan 09, 2018
* Missing files are included (```src/data/MyImage.py```).
* Some links are fixed.
* Jan 16, 2018
* Memory efficient forward function is implemented.
* Add --chop_forward argument to your script to enable it.
* Basically, this function first splits a large image into small patches. Those patches are merged back after super-resolution. I checked this function with 12GB memory and a 4000 x 2000 input image at scale 4. (Therefore, the output will be 16000 x 8000.)
* Feb 21, 2018
* Fixed the problem when loading pre-trained multi-gpu model.
* Added pre-trained scale 2 baseline model.
* This code now only saves the best-performing model by default. For MDSR, 'the best' can be ambiguous. Use --save_models argument to save all the intermediate models.
* PyTorch 0.3.1 changed their implementation of DataLoader function. Therefore, I also changed my implementation of MSDataLoader. You can find it on feature/dataloader branch.
* Feb 23, 2018
* Now PyTorch 0.3.1 is default. Use legacy/0.3.0 branch if you use the old version.
* With the new ``src/data/DIV2K.py`` code, one can easily create a new data class for super-resolution.
* New binary data pack. (Please remove the ``DIV2K_decoded`` folder from your dataset if you have one.)
* With ``--ext bin``, this code will automatically generate and save the binary data pack that corresponds to the previous ``DIV2K_decoded``. (This requires a huge amount of RAM (~45GB; swap can be used), so please be careful.)
* If you cannot make the binary pack, just use the default setting (``--ext img``).
* Fixed a bug where the PSNR in the log and the PSNR calculated from the saved images did not match.
* Now saved images have better quality! (PSNR is ~0.1dB higher than the original code.)
* Added performance comparison between Torch7 model and PyTorch models.
* Mar 5, 2018
* All baseline models are uploaded.
* Now supports half-precision at test time. Use ``--precision half`` to enable it. This does not degrade the output images.
* Mar 11, 2018
* Fixed some typos in the code and script.
* Now --ext img is the default setting. Although we recommend using --ext bin when training, please use --ext img when you use --test_only.
* Skip_batch operation is implemented. Use the --skip_threshold argument to skip batches that you want to ignore. Although this function is not exactly the same as that of the Torch7 version, it works as expected.
* Mar 20, 2018
* Use ``--ext sep_reset`` to pre-decode large png files. Those decoded files will be saved to the same directory with DIV2K png files. After the first run, you can use ``--ext sep`` to save time.
* Now supports various benchmark datasets. For example, try ``--data_test Set5`` to test your model on the Set5 images.
* Changed the behavior of skip_batch.
* Mar 29, 2018
* We now provide all models from our paper.
* We also provide the ``MDSR_baseline_jpeg`` model that suppresses JPEG artifacts in the original low-resolution image. Please use it if you have any trouble.
* The ``MyImage`` dataset is changed to the ``Demo`` dataset. Also, it works more efficiently than before.
* Some codes and script are re-written.
* Apr 9, 2018
* VGG and adversarial losses are implemented based on [SRGAN](http://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf). [WGAN](https://arxiv.org/abs/1701.07875) and [gradient penalty](https://arxiv.org/abs/1704.00028) are also implemented, but they are not tested yet.
* Many codes are refactored. If there exists a bug, please report it.
* [D-DBPN](https://arxiv.org/abs/1803.02735) is implemented. Default setting is D-DBPN-L.
* Apr 26, 2018
* Compatible with PyTorch 0.4.0
* Please use the legacy/0.3.1 branch if you are using the old version of PyTorch.
* Minor bug fixes
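The memory-efficient forward mentioned in the Jan 16 update (implemented as `forward_chop` in `src/model/__init__.py`) splits the input into four overlapping quadrants, super-resolves each separately, and stitches the results. A minimal single-channel NumPy sketch at scale 1 (the real code also multiplies all indices by the SR scale and recurses when patches are still too large):

```python
import numpy as np

def chop_forward(x, model, shave=4):
    """Sketch of the quadrant split/merge behind --chop_forward (scale-1 model assumed)."""
    h, w = x.shape
    h_half, w_half = h // 2, w // 2
    h_size, w_size = h_half + shave, w_half + shave   # quadrants overlap by `shave`
    quads = [x[0:h_size, 0:w_size], x[0:h_size, w - w_size:w],
             x[h - h_size:h, 0:w_size], x[h - h_size:h, w - w_size:w]]
    outs = [model(q) for q in quads]                  # process each small patch separately
    y = np.empty_like(x)                              # keep only the non-overlapping cores
    y[0:h_half, 0:w_half] = outs[0][0:h_half, 0:w_half]
    y[0:h_half, w_half:w] = outs[1][0:h_half, w_size - (w - w_half):w_size]
    y[h_half:h, 0:w_half] = outs[2][h_size - (h - h_half):h_size, 0:w_half]
    y[h_half:h, w_half:w] = outs[3][h_size - (h - h_half):h_size, w_size - (w - w_half):w_size]
    return y
```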
================================================
FILE: src/model/__init__.py
================================================
import os
from importlib import import_module
import torch
import torch.nn as nn
from torch.autograd import Variable
class Model(nn.Module):
def __init__(self, args, ckp):
super(Model, self).__init__()
print('Making model...')
self.scale = args.scale
self.idx_scale = 0
self.self_ensemble = args.self_ensemble
self.chop = args.chop
self.precision = args.precision
self.cpu = args.cpu
self.device = torch.device('cpu' if args.cpu else 'cuda')
self.n_GPUs = args.n_GPUs
self.save_models = args.save_models
module = import_module('model.' + args.model.lower())
self.model = module.make_model(args).to(self.device)
if args.precision == 'half': self.model.half()
if not args.cpu and args.n_GPUs > 1:
self.model = nn.DataParallel(self.model, range(args.n_GPUs))
self.load(
ckp.dir,
pre_train=args.pre_train,
resume=args.resume,
cpu=args.cpu
)
print(self.model, file=ckp.log_file)
def forward(self, x, idx_scale):
self.idx_scale = idx_scale
target = self.get_model()
if hasattr(target, 'set_scale'):
target.set_scale(idx_scale)
if self.self_ensemble and not self.training:
if self.chop:
forward_function = self.forward_chop
else:
forward_function = self.model.forward
return self.forward_x8(x, forward_function)
elif self.chop and not self.training:
return self.forward_chop(x)
else:
return self.model(x)
def get_model(self):
if self.n_GPUs == 1:
return self.model
else:
return self.model.module
def state_dict(self, **kwargs):
target = self.get_model()
return target.state_dict(**kwargs)
def save(self, apath, epoch, is_best=False):
target = self.get_model()
torch.save(
target.state_dict(),
os.path.join(apath, 'model_latest.pt')
)
if is_best:
torch.save(
target.state_dict(),
os.path.join(apath, 'model_best.pt')
)
if self.save_models:
torch.save(
target.state_dict(),
os.path.join(apath, 'model_{}.pt'.format(epoch))
)
def load(self, apath, pre_train='.', resume=-1, cpu=False):
if cpu:
kwargs = {'map_location': lambda storage, loc: storage}
else:
kwargs = {}
if resume == -1:
self.get_model().load_state_dict(
torch.load(
os.path.join(apath, 'model_latest.pt'),
**kwargs
),
strict=False
)
elif resume == 0:
if pre_train != '.':
print('Loading model from {}'.format(pre_train))
self.get_model().load_state_dict(
torch.load(pre_train, **kwargs),
strict=False
)
else:
self.get_model().load_state_dict(
torch.load(
os.path.join(apath, 'model', 'model_{}.pt'.format(resume)),
**kwargs
),
strict=False
)
def forward_chop(self, x, shave=10, min_size=120000):
scale = self.scale[self.idx_scale]
n_GPUs = min(self.n_GPUs, 4)
b, c, h, w = x.size()
h_half, w_half = h // 2, w // 2
h_size, w_size = h_half + shave, w_half + shave
h_size +=4-h_size%4
w_size +=8-w_size%8
lr_list = [
x[:, :, 0:h_size, 0:w_size],
x[:, :, 0:h_size, (w - w_size):w],
x[:, :, (h - h_size):h, 0:w_size],
x[:, :, (h - h_size):h, (w - w_size):w]]
if w_size * h_size < min_size:
sr_list = []
for i in range(0, 4, n_GPUs):
lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
sr_batch = self.model(lr_batch)
sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
else:
sr_list = [
self.forward_chop(patch, shave=shave, min_size=min_size) \
for patch in lr_list
]
h, w = scale * h, scale * w
h_half, w_half = scale * h_half, scale * w_half
h_size, w_size = scale * h_size, scale * w_size
shave *= scale
output = x.new(b, c, h, w)
output[:, :, 0:h_half, 0:w_half] \
= sr_list[0][:, :, 0:h_half, 0:w_half]
output[:, :, 0:h_half, w_half:w] \
= sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
output[:, :, h_half:h, 0:w_half] \
= sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
output[:, :, h_half:h, w_half:w] \
= sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
return output
def forward_x8(self, x, forward_function):
def _transform(v, op):
if self.precision != 'single': v = v.float()
v2np = v.data.cpu().numpy()
if op == 'v':
tfnp = v2np[:, :, :, ::-1].copy()
elif op == 'h':
tfnp = v2np[:, :, ::-1, :].copy()
elif op == 't':
tfnp = v2np.transpose((0, 1, 3, 2)).copy()
ret = torch.Tensor(tfnp).to(self.device)
if self.precision == 'half': ret = ret.half()
return ret
lr_list = [x]
for tf in 'v', 'h', 't':
lr_list.extend([_transform(t, tf) for t in lr_list])
sr_list = [forward_function(aug) for aug in lr_list]
for i in range(len(sr_list)):
if i > 3:
sr_list[i] = _transform(sr_list[i], 't')
if i % 4 > 1:
sr_list[i] = _transform(sr_list[i], 'h')
if (i % 4) % 2 == 1:
sr_list[i] = _transform(sr_list[i], 'v')
output_cat = torch.cat(sr_list, dim=0)
output = output_cat.mean(dim=0, keepdim=True)
return output
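`forward_x8` above implements geometric self-ensemble: run the model on all eight flip/transpose variants of the input, undo each transform on the corresponding output, and average. A NumPy sketch for a single-channel square image (toy stand-in, not the repo's API):

```python
import numpy as np

def self_ensemble_x8(x, model):
    """Average a model over the 8 flip/transpose symmetries of a square image."""
    outs = []
    for t in (False, True):               # transpose or not
        for fh in (False, True):          # horizontal flip
            for fv in (False, True):      # vertical flip
                a = x.T if t else x
                if fh: a = a[::-1, :]
                if fv: a = a[:, ::-1]
                y = model(a)
                # undo the transforms in reverse order
                if fv: y = y[:, ::-1]
                if fh: y = y[::-1, :]
                if t: y = y.T
                outs.append(y)
    return np.mean(outs, axis=0)
```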
================================================
FILE: src/model/attention.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import common

class NonLocalSparseAttention(nn.Module):
    def __init__(self, n_hashes=4, channels=64, k_size=3, reduction=4, chunk_size=144, conv=common.default_conv, res_scale=1):
        super(NonLocalSparseAttention, self).__init__()
        self.chunk_size = chunk_size
        self.n_hashes = n_hashes
        self.reduction = reduction
        self.res_scale = res_scale
        self.conv_match = common.BasicBlock(conv, channels, channels//reduction, k_size, bn=False, act=None)
        self.conv_assembly = common.BasicBlock(conv, channels, channels, 1, bn=False, act=None)

    def LSH(self, hash_buckets, x):
        #x: [N,H*W,C]
        N = x.shape[0]
        device = x.device

        #generate random rotation matrix
        rotations_shape = (1, x.shape[-1], self.n_hashes, hash_buckets//2) #[1,C,n_hashes,hash_buckets//2]
        random_rotations = torch.randn(rotations_shape, dtype=x.dtype, device=device).expand(N, -1, -1, -1) #[N,C,n_hashes,hash_buckets//2]

        #locality sensitive hashing
        rotated_vecs = torch.einsum('btf,bfhi->bhti', x, random_rotations) #[N,n_hashes,H*W,hash_buckets//2]
        rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1) #[N,n_hashes,H*W,hash_buckets]

        #get hash codes
        hash_codes = torch.argmax(rotated_vecs, dim=-1) #[N,n_hashes,H*W]

        #add offsets to avoid hash codes overlapping between hash rounds
        offsets = torch.arange(self.n_hashes, device=device)
        offsets = torch.reshape(offsets * hash_buckets, (1, -1, 1))
        hash_codes = torch.reshape(hash_codes + offsets, (N, -1,)) #[N,n_hashes*H*W]

        return hash_codes

    def add_adjacent_buckets(self, x):
        x_extra_back = torch.cat([x[:,:,-1:, ...], x[:,:,:-1, ...]], dim=2)
        x_extra_forward = torch.cat([x[:,:,1:, ...], x[:,:,:1, ...]], dim=2)
        return torch.cat([x, x_extra_back, x_extra_forward], dim=3)

    def forward(self, input):
        N,_,H,W = input.shape
        x_embed = self.conv_match(input).view(N,-1,H*W).contiguous().permute(0,2,1)
        y_embed = self.conv_assembly(input).view(N,-1,H*W).contiguous().permute(0,2,1)
        L,C = x_embed.shape[-2:]

        #number of hash buckets/hash bits
        hash_buckets = min(L//self.chunk_size + (L//self.chunk_size)%2, 128)

        #get assigned hash codes/bucket number
        hash_codes = self.LSH(hash_buckets, x_embed) #[N,n_hashes*H*W]
        hash_codes = hash_codes.detach()

        #group elements with same hash code by sorting
        _, indices = hash_codes.sort(dim=-1) #[N,n_hashes*H*W]
        _, undo_sort = indices.sort(dim=-1) #undo_sort to recover original order
        mod_indices = (indices % L) #now range from (0->H*W)
        x_embed_sorted = common.batched_index_select(x_embed, mod_indices) #[N,n_hashes*H*W,C]
        y_embed_sorted = common.batched_index_select(y_embed, mod_indices) #[N,n_hashes*H*W,C*reduction]

        #pad the embedding if it cannot be divided by chunk_size
        padding = self.chunk_size - L%self.chunk_size if L%self.chunk_size != 0 else 0
        x_att_buckets = torch.reshape(x_embed_sorted, (N, self.n_hashes, -1, C)) #[N,n_hashes,H*W,C]
        y_att_buckets = torch.reshape(y_embed_sorted, (N, self.n_hashes, -1, C*self.reduction))
        if padding:
            pad_x = x_att_buckets[:,:,-padding:,:].clone()
            pad_y = y_att_buckets[:,:,-padding:,:].clone()
            x_att_buckets = torch.cat([x_att_buckets, pad_x], dim=2)
            y_att_buckets = torch.cat([y_att_buckets, pad_y], dim=2)

        x_att_buckets = torch.reshape(x_att_buckets, (N, self.n_hashes, -1, self.chunk_size, C)) #[N,n_hashes,num_chunks,chunk_size,C]
        y_att_buckets = torch.reshape(y_att_buckets, (N, self.n_hashes, -1, self.chunk_size, C*self.reduction))

        x_match = F.normalize(x_att_buckets, p=2, dim=-1, eps=5e-5)

        #allow attending to adjacent buckets
        x_match = self.add_adjacent_buckets(x_match)
        y_att_buckets = self.add_adjacent_buckets(y_att_buckets)

        #unnormalized attention score
        raw_score = torch.einsum('bhkie,bhkje->bhkij', x_att_buckets, x_match) #[N,n_hashes,num_chunks,chunk_size,chunk_size*3]

        #softmax
        bucket_score = torch.logsumexp(raw_score, dim=-1, keepdim=True)
        score = torch.exp(raw_score - bucket_score) #(after softmax)
        bucket_score = torch.reshape(bucket_score, [N, self.n_hashes, -1])

        #attention
        ret = torch.einsum('bukij,bukje->bukie', score, y_att_buckets) #[N,n_hashes,num_chunks,chunk_size,C*reduction]
        ret = torch.reshape(ret, (N, self.n_hashes, -1, C*self.reduction))

        #if padded, then remove extra elements
        if padding:
            ret = ret[:,:,:-padding,:].clone()
            bucket_score = bucket_score[:,:,:-padding].clone()

        #recover the original order
        ret = torch.reshape(ret, (N, -1, C*self.reduction)) #[N,n_hashes*H*W,C*reduction]
        bucket_score = torch.reshape(bucket_score, (N, -1,)) #[N,n_hashes*H*W]
        ret = common.batched_index_select(ret, undo_sort) #[N,n_hashes*H*W,C*reduction]
        bucket_score = bucket_score.gather(1, undo_sort) #[N,n_hashes*H*W]

        #weighted sum over the multi-round attention results
        ret = torch.reshape(ret, (N, self.n_hashes, L, C*self.reduction)) #[N,n_hashes,H*W,C*reduction]
        bucket_score = torch.reshape(bucket_score, (N, self.n_hashes, L, 1))
        probs = nn.functional.softmax(bucket_score, dim=1)
        ret = torch.sum(ret * probs, dim=1)

        ret = ret.permute(0,2,1).view(N,-1,H,W).contiguous()*self.res_scale + input
        return ret

class NonLocalAttention(nn.Module):
    def __init__(self, channel=128, reduction=2, ksize=1, scale=3, stride=1, softmax_scale=10, average=True, res_scale=1, conv=common.default_conv):
        super(NonLocalAttention, self).__init__()
        self.res_scale = res_scale
        self.conv_match1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_match2 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())
        self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())

    def forward(self, input):
        x_embed_1 = self.conv_match1(input)
        x_embed_2 = self.conv_match2(input)
        x_assembly = self.conv_assembly(input)

        N,C,H,W = x_embed_1.shape
        x_embed_1 = x_embed_1.permute(0,2,3,1).view((N,H*W,C))
        x_embed_2 = x_embed_2.view(N,C,H*W)
        score = torch.matmul(x_embed_1, x_embed_2)
        score = F.softmax(score, dim=2)
        x_assembly = x_assembly.view(N,-1,H*W).permute(0,2,1)
        x_final = torch.matmul(score, x_assembly)
        return x_final.permute(0,2,1).view(N,-1,H,W) + self.res_scale*input
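The `LSH` method above is angular (spherical) LSH: project each feature vector onto random directions, concatenate the projections with their negations, and take the argmax as the bucket id, so that vectors pointing in similar directions land in the same bucket. A single-image, single-hash-round NumPy sketch (hypothetical helper, illustration only):

```python
import numpy as np

def spherical_lsh(x, n_buckets, seed=0):
    """Angular LSH sketch: x is [L, C]; returns a bucket id in [0, n_buckets) per vector."""
    rng = np.random.default_rng(seed)
    r = rng.standard_normal((x.shape[1], n_buckets // 2))  # random projection directions
    proj = x @ r                                           # [L, n_buckets//2]
    proj = np.concatenate([proj, -proj], axis=-1)          # [L, n_buckets]: signed projections
    return np.argmax(proj, axis=-1)                        # nearest "rotated pole" = bucket id
```

Because the bucket id depends only on the direction of `x`, positively scaled copies of a vector always hash to the same bucket.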
================================================
FILE: src/model/common.py
================================================
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

def batched_index_select(values, indices):
    last_dim = values.shape[-1]
    return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim))

def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), stride=stride, bias=bias)

class MeanShift(nn.Conv2d):
    def __init__(
        self, rgb_range,
        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False

class BasicBlock(nn.Sequential):
    def __init__(
        self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,
        bn=False, act=nn.PReLU()):
        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)

class ResBlock(nn.Module):
    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=nn.PReLU(), res_scale=1):
        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                m.append(act)
        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x
        return res

class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
        m = []
        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feats, 4 * n_feats, 3, bias=bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feats))
                if act == 'relu':
                    m.append(nn.ReLU(True))
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))
        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias=bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                m.append(nn.ReLU(True))
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError
        super(Upsampler, self).__init__(*m)
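`Upsampler` relies on `nn.PixelShuffle`: a conv first expands the channels by r², then the shuffle rearranges each group of r² channels into an r× larger spatial grid. A NumPy sketch of that rearrangement for a single image (illustration, not the repo's code):

```python
import numpy as np

def pixel_shuffle(x, r):
    """NumPy sketch of nn.PixelShuffle: [C*r*r, H, W] -> [C, H*r, W*r]."""
    c2, h, w = x.shape
    c = c2 // (r * r)
    x = x.reshape(c, r, r, h, w)       # split channels into (C, r, r) blocks
    x = x.transpose(0, 3, 1, 4, 2)     # interleave: [C, H, r, W, r]
    return x.reshape(c, h * r, w * r)  # merge into the upscaled grid
```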
================================================
FILE: src/model/ddbpn.py
================================================
# Deep Back-Projection Networks For Super-Resolution
# https://arxiv.org/abs/1803.02735
from model import common
import torch
import torch.nn as nn
def make_model(args, parent=False):
return DDBPN(args)
def projection_conv(in_channels, out_channels, scale, up=True):
kernel_size, stride, padding = {
2: (6, 2, 2),
4: (8, 4, 2),
8: (12, 8, 2)
}[scale]
if up:
conv_f = nn.ConvTranspose2d
else:
conv_f = nn.Conv2d
return conv_f(
in_channels, out_channels, kernel_size,
stride=stride, padding=padding
)
class DenseProjection(nn.Module):
def __init__(self, in_channels, nr, scale, up=True, bottleneck=True):
super(DenseProjection, self).__init__()
if bottleneck:
self.bottleneck = nn.Sequential(*[
nn.Conv2d(in_channels, nr, 1),
nn.PReLU(nr)
])
inter_channels = nr
else:
self.bottleneck = None
inter_channels = in_channels
self.conv_1 = nn.Sequential(*[
projection_conv(inter_channels, nr, scale, up),
nn.PReLU(nr)
])
self.conv_2 = nn.Sequential(*[
projection_conv(nr, inter_channels, scale, not up),
nn.PReLU(inter_channels)
])
self.conv_3 = nn.Sequential(*[
projection_conv(inter_channels, nr, scale, up),
nn.PReLU(nr)
])
def forward(self, x):
if self.bottleneck is not None:
x = self.bottleneck(x)
a_0 = self.conv_1(x)
b_0 = self.conv_2(a_0)
e = b_0.sub(x)
a_1 = self.conv_3(e)
out = a_0.add(a_1)
return out
class DDBPN(nn.Module):
def __init__(self, args):
super(DDBPN, self).__init__()
scale = args.scale[0]
n0 = 128
nr = 32
self.depth = 6
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
initial = [
nn.Conv2d(args.n_colors, n0, 3, padding=1),
nn.PReLU(n0),
nn.Conv2d(n0, nr, 1),
nn.PReLU(nr)
]
self.initial = nn.Sequential(*initial)
self.upmodules = nn.ModuleList()
self.downmodules = nn.ModuleList()
channels = nr
for i in range(self.depth):
self.upmodules.append(
DenseProjection(channels, nr, scale, True, i > 1)
)
if i != 0:
channels += nr
channels = nr
for i in range(self.depth - 1):
self.downmodules.append(
DenseProjection(channels, nr, scale, False, i != 0)
)
channels += nr
reconstruction = [
nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1)
]
self.reconstruction = nn.Sequential(*reconstruction)
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
def forward(self, x):
x = self.sub_mean(x)
x = self.initial(x)
h_list = []
l_list = []
for i in range(self.depth - 1):
if i == 0:
l = x
else:
l = torch.cat(l_list, dim=1)
h_list.append(self.upmodules[i](l))
l_list.append(self.downmodules[i](torch.cat(h_list, dim=1)))
h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1)))
out = self.reconstruction(torch.cat(h_list, dim=1))
out = self.add_mean(out)
return out
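`DenseProjection.forward` above follows the back-projection pattern: upsample, project back down, and upsample the low-resolution residual as a correction. A 1-D NumPy sketch with nearest-neighbour/average stand-ins for the learned projection convs (so the residual happens to vanish here; with learned projections it would not):

```python
import numpy as np

def up(x, s=2):       # nearest-neighbour upsample (stand-in for the learned deconv)
    return np.repeat(x, s)

def down(x, s=2):     # average-pool downsample (stand-in for the learned conv)
    return x.reshape(-1, s).mean(axis=1)

def up_projection(x):
    """1-D sketch of DenseProjection's error feedback (up branch)."""
    a0 = up(x)        # first upsampled estimate
    b0 = down(a0)     # project back to low resolution
    e = b0 - x        # residual in LR space
    a1 = up(e)        # upsample the residual
    return a0 + a1    # corrected HR estimate
```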
================================================
FILE: src/model/edsr.py
================================================
from model import common
from model import attention
import torch.nn as nn
def make_model(args, parent=False):
if args.dilation:
from model import dilated
return EDSR(args, dilated.dilated_conv)
else:
return EDSR(args)
class EDSR(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(EDSR, self).__init__()
n_resblock = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
scale = args.scale[0]
act = nn.ReLU(True)
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
#self.msa = attention.PyramidAttention(channel=256, reduction=8,res_scale=args.res_scale);
# define head module
m_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
m_body = [
common.ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
) for _ in range(n_resblock//2)
]
#m_body.append(self.msa)
for _ in range(n_resblock//2):
m_body.append( common.ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
))
m_body.append(conv(n_feats, n_feats, kernel_size))
# define tail module
m_tail = [
common.Upsampler(conv, scale, n_feats, act=False),
nn.Conv2d(
n_feats, args.n_colors, kernel_size,
padding=(kernel_size//2)
)
]
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.Sequential(*m_head)
self.body = nn.Sequential(*m_body)
self.tail = nn.Sequential(*m_tail)
def forward(self, x):
x = self.sub_mean(x)
x = self.head(x)
res = self.body(x)
res += x
x = self.tail(res)
x = self.add_mean(x)
return x
def load_state_dict(self, state_dict, strict=True):
own_state = self.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
param = param.data
try:
own_state[name].copy_(param)
except Exception:
if name.find('tail') == -1:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
if name.find('tail') == -1:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
================================================
FILE: src/model/mdsr.py
================================================
from model import common
import torch.nn as nn
def make_model(args, parent=False):
return MDSR(args)
class MDSR(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(MDSR, self).__init__()
n_resblocks = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
self.scale_idx = 0
act = nn.ReLU(True)
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
m_head = [conv(args.n_colors, n_feats, kernel_size)]
self.pre_process = nn.ModuleList([
nn.Sequential(
common.ResBlock(conv, n_feats, 5, act=act),
common.ResBlock(conv, n_feats, 5, act=act)
) for _ in args.scale
])
m_body = [
common.ResBlock(
conv, n_feats, kernel_size, act=act
) for _ in range(n_resblocks)
]
m_body.append(conv(n_feats, n_feats, kernel_size))
self.upsample = nn.ModuleList([
common.Upsampler(
conv, s, n_feats, act=False
) for s in args.scale
])
m_tail = [conv(n_feats, args.n_colors, kernel_size)]
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.Sequential(*m_head)
self.body = nn.Sequential(*m_body)
self.tail = nn.Sequential(*m_tail)
def forward(self, x):
x = self.sub_mean(x)
x = self.head(x)
x = self.pre_process[self.scale_idx](x)
res = self.body(x)
res += x
x = self.upsample[self.scale_idx](res)
x = self.tail(x)
x = self.add_mean(x)
return x
def set_scale(self, scale_idx):
self.scale_idx = scale_idx
================================================
FILE: src/model/mssr.py
================================================
from model import common
import torch.nn as nn
import torch
from model.attention import ContextualAttention,NonLocalAttention
def make_model(args, parent=False):
return MSSR(args)
class MultisourceProjection(nn.Module):
def __init__(self, in_channel,kernel_size = 3, conv=common.default_conv):
super(MultisourceProjection, self).__init__()
self.up_attention = ContextualAttention(scale=2)
self.down_attention = NonLocalAttention()
self.upsample = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])
self.encoder = common.ResBlock(conv, in_channel, kernel_size, act=nn.PReLU(), res_scale=1)
def forward(self,x):
down_map = self.upsample(self.down_attention(x))
up_map = self.up_attention(x)
err = self.encoder(up_map-down_map)
final_map = down_map + err
return final_map
class RecurrentProjection(nn.Module):
def __init__(self, in_channel,kernel_size = 3, conv=common.default_conv):
super(RecurrentProjection, self).__init__()
self.multi_source_projection_1 = MultisourceProjection(in_channel,kernel_size=kernel_size,conv=conv)
self.multi_source_projection_2 = MultisourceProjection(in_channel,kernel_size=kernel_size,conv=conv)
self.down_sample_1 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])
#self.down_sample_2 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])
self.down_sample_3 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()])
self.down_sample_4 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()])
self.error_encode_1 = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])
self.error_encode_2 = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()])
self.post_conv = common.BasicBlock(conv,in_channel,in_channel,kernel_size,stride=1,bias=True,act=nn.PReLU())
def forward(self, x):
x_up = self.multi_source_projection_1(x)
x_down = self.down_sample_1(x_up)
error_up = self.error_encode_1(x-x_down)
h_estimate_1 = x_up + error_up
x_up_2 = self.multi_source_projection_2(h_estimate_1)
x_down_2 = self.down_sample_3(x_up_2)
error_up_2 = self.error_encode_2(x-x_down_2)
h_estimate_2 = x_up_2 + error_up_2
x_final = self.post_conv(self.down_sample_4(h_estimate_2))
return x_final, h_estimate_2
class MSSR(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(MSSR, self).__init__()
#n_convblock = args.n_convblocks
n_feats = args.n_feats
self.depth = args.depth
kernel_size = 3
scale = args.scale[0]
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
# define head module
m_head = [common.BasicBlock(conv, args.n_colors, n_feats, kernel_size,stride=1,bias=True,bn=False,act=nn.PReLU()),
common.BasicBlock(conv,n_feats, n_feats, kernel_size,stride=1,bias=True,bn=False,act=nn.PReLU())]
# define multiple reconstruction module
self.body = RecurrentProjection(n_feats)
# define tail module
m_tail = [
nn.Conv2d(
n_feats*self.depth, args.n_colors, kernel_size,
padding=(kernel_size//2)
)
]
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.Sequential(*m_head)
self.tail = nn.Sequential(*m_tail)
def forward(self,input):
x = self.sub_mean(input)
x = self.head(x)
bag = []
for i in range(self.depth):
x, h_estimate = self.body(x)
bag.append(h_estimate)
h_feature = torch.cat(bag,dim=1)
h_final = self.tail(h_feature)
return self.add_mean(h_final)
================================================
FILE: src/model/nlsn.py
================================================
from model import common
from model import attention
import torch.nn as nn
def make_model(args, parent=False):
if args.dilation:
from model import dilated
return NLSN(args, dilated.dilated_conv)
else:
return NLSN(args)
class NLSN(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(NLSN, self).__init__()
n_resblock = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
scale = args.scale[0]
act = nn.ReLU(True)
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
m_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
m_body = [attention.NonLocalSparseAttention(
channels=n_feats, chunk_size=args.chunk_size, n_hashes=args.n_hashes, reduction=4, res_scale=args.res_scale)]
for i in range(n_resblock):
m_body.append( common.ResBlock(
conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
))
if (i+1)%8==0:
m_body.append(attention.NonLocalSparseAttention(
channels=n_feats, chunk_size=args.chunk_size, n_hashes=args.n_hashes, reduction=4, res_scale=args.res_scale))
m_body.append(conv(n_feats, n_feats, kernel_size))
# define tail module
m_tail = [
common.Upsampler(conv, scale, n_feats, act=False),
nn.Conv2d(
n_feats, args.n_colors, kernel_size,
padding=(kernel_size//2)
)
]
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.Sequential(*m_head)
self.body = nn.Sequential(*m_body)
self.tail = nn.Sequential(*m_tail)
def forward(self, x):
x = self.sub_mean(x)
x = self.head(x)
res = self.body(x)
res += x
x = self.tail(res)
x = self.add_mean(x)
return x
def load_state_dict(self, state_dict, strict=True):
own_state = self.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
param = param.data
try:
own_state[name].copy_(param)
except Exception:
if name.find('tail') == -1:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
if name.find('tail') == -1:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
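The NLSN body above interleaves one NonLocalSparseAttention module before the residual blocks and one more after every eighth block, followed by a closing conv. A minimal standalone sketch (plain Python, hypothetical helper name) of how many modules that loop produces:

```python
def nlsn_body_layout(n_resblocks, period=8):
    """Count modules in the NLSN body: a leading attention module,
    n_resblocks residual blocks, one extra attention after every
    `period`-th block, and a trailing conv."""
    attention = 1  # attention module placed before the loop
    modules = 1
    for i in range(n_resblocks):
        modules += 1  # ResBlock
        if (i + 1) % period == 0:
            modules += 1  # NonLocalSparseAttention
            attention += 1
    modules += 1  # trailing conv
    return modules, attention
```

For 32 residual blocks this gives 5 attention modules (one leading, four interleaved) and 38 body modules in total.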
================================================
FILE: src/model/rcan.py
================================================
## ECCV-2018-Image Super-Resolution Using Very Deep Residual Channel Attention Networks
## https://arxiv.org/abs/1807.02758
from model import common
import torch.nn as nn
import torch
def make_model(args, parent=False):
return RCAN(args)
## Channel Attention (CA) Layer
class CALayer(nn.Module):
def __init__(self, channel, reduction=16):
super(CALayer, self).__init__()
# global average pooling: feature --> point
self.avg_pool = nn.AdaptiveAvgPool2d(1)
# feature channel downscale and upscale --> channel weight
#self.a = torch.nn.Parameter(torch.Tensor([0]))
#self.a.requires_grad=True
self.conv_du = nn.Sequential(
nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
nn.Sigmoid()
)
def forward(self, x):
y = self.avg_pool(x)
y = self.conv_du(y)
return x * y
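CALayer squeezes each channel to its spatial average, passes it through a bottleneck (conv, ReLU, conv, sigmoid), and rescales the channel by the resulting weight. A scalar sketch in plain Python (hypothetical single-channel case with unit weights, not the learned convs):

```python
import math

def channel_gate(channel_mean, w1=1.0, w2=1.0):
    """Scalar sketch of the CALayer gate: bottleneck multiply + ReLU,
    expand multiply + sigmoid."""
    h = max(0.0, w1 * channel_mean)           # 1x1 conv + ReLU
    return 1.0 / (1.0 + math.exp(-w2 * h))    # 1x1 conv + sigmoid

def apply_ca(pixels, w1=1.0, w2=1.0):
    """Rescale one channel's pixels by its attention weight."""
    mean = sum(pixels) / len(pixels)          # global average pooling
    return [p * channel_gate(mean, w1, w2) for p in pixels]
```

A zero-mean channel gets the neutral gate sigmoid(0) = 0.5, so its pixels are halved.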
## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
def __init__(
self, conv, n_feat, kernel_size, reduction,
bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
super(RCAB, self).__init__()
modules_body = []
for i in range(2):
modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
if bn: modules_body.append(nn.BatchNorm2d(n_feat))
if i == 0: modules_body.append(act)
modules_body.append(CALayer(n_feat, reduction))
self.body = nn.Sequential(*modules_body)
self.res_scale = res_scale
def forward(self, x):
res = self.body(x)
#res = self.body(x).mul(self.res_scale)
res += x
return res
## Residual Group (RG)
class ResidualGroup(nn.Module):
def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
super(ResidualGroup, self).__init__()
modules_body = []
modules_body = [
RCAB(
conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \
for _ in range(n_resblocks)]
modules_body.append(conv(n_feat, n_feat, kernel_size))
self.body = nn.Sequential(*modules_body)
def forward(self, x):
res = self.body(x)
res += x
return res
## Residual Channel Attention Network (RCAN)
class RCAN(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(RCAN, self).__init__()
self.a = nn.Parameter(torch.Tensor([0]))
self.a.requires_grad=True
n_resgroups = args.n_resgroups
n_resblocks = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
reduction = args.reduction
scale = args.scale[0]
act = nn.ReLU(True)
# RGB mean for DIV2K
rgb_mean = (0.4488, 0.4371, 0.4040)
rgb_std = (1.0, 1.0, 1.0)
self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
# define head module
modules_head = [conv(args.n_colors, n_feats, kernel_size)]
# define body module
modules_body = [
ResidualGroup(
conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \
for _ in range(n_resgroups)]
modules_body.append(conv(n_feats, n_feats, kernel_size))
# define tail module
modules_tail = [
common.Upsampler(conv, scale, n_feats, act=False),
conv(n_feats, args.n_colors, kernel_size)]
self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
self.head = nn.Sequential(*modules_head)
self.body = nn.Sequential(*modules_body)
self.tail = nn.Sequential(*modules_tail)
def forward(self, x):
x = self.sub_mean(x)
x = self.head(x)
res = self.body(x)
res += x
x = self.tail(res)
x = self.add_mean(x)
return x
def load_state_dict(self, state_dict, strict=False):
own_state = self.state_dict()
for name, param in state_dict.items():
if name in own_state:
if isinstance(param, nn.Parameter):
param = param.data
try:
own_state[name].copy_(param)
except Exception:
if name.find('msa') >= 0 or name.find('a') >= 0:
print('Replace pre-trained upsampler to new one...')
else:
raise RuntimeError('While copying the parameter named {}, '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}.'
.format(name, own_state[name].size(), param.size()))
elif strict:
if name.find('msa') == -1:
raise KeyError('unexpected key "{}" in state_dict'
.format(name))
if strict:
missing = set(own_state.keys()) - set(state_dict.keys())
if len(missing) > 0:
raise KeyError('missing keys in state_dict: "{}"'.format(missing))
================================================
FILE: src/model/rdn.py
================================================
# Residual Dense Network for Image Super-Resolution
# https://arxiv.org/abs/1802.08797
from model import common
import torch
import torch.nn as nn
def make_model(args, parent=False):
return RDN(args)
class RDB_Conv(nn.Module):
def __init__(self, inChannels, growRate, kSize=3):
super(RDB_Conv, self).__init__()
Cin = inChannels
G = growRate
self.conv = nn.Sequential(*[
nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1),
nn.ReLU()
])
def forward(self, x):
out = self.conv(x)
return torch.cat((x, out), 1)
class RDB(nn.Module):
def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
super(RDB, self).__init__()
G0 = growRate0
G = growRate
C = nConvLayers
convs = []
for c in range(C):
convs.append(RDB_Conv(G0 + c*G, G))
self.convs = nn.Sequential(*convs)
# Local Feature Fusion
self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1)
def forward(self, x):
return self.LFF(self.convs(x)) + x
class RDN(nn.Module):
def __init__(self, args):
super(RDN, self).__init__()
r = args.scale[0]
G0 = args.G0
kSize = args.RDNkSize
# number of RDB blocks, conv layers, out channels
self.D, C, G = {
'A': (20, 6, 32),
'B': (16, 8, 64),
}[args.RDNconfig]
# Shallow feature extraction net
self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)
self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
# Residual dense blocks and dense feature fusion
self.RDBs = nn.ModuleList()
for i in range(self.D):
self.RDBs.append(
RDB(growRate0 = G0, growRate = G, nConvLayers = C)
)
# Global Feature Fusion
self.GFF = nn.Sequential(*[
nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
])
# Up-sampling net
if r == 2 or r == 3:
self.UPNet = nn.Sequential(*[
nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),
nn.PixelShuffle(r),
nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)
])
elif r == 4:
self.UPNet = nn.Sequential(*[
nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),
nn.PixelShuffle(2),
nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),
nn.PixelShuffle(2),
nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)
])
else:
raise ValueError("scale must be 2 or 3 or 4.")
def forward(self, x):
f__1 = self.SFENet1(x)
x = self.SFENet2(f__1)
RDBs_out = []
for i in range(self.D):
x = self.RDBs[i](x)
RDBs_out.append(x)
x = self.GFF(torch.cat(RDBs_out,1))
x += f__1
return self.UPNet(x)
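The channel bookkeeping inside an RDB follows from the concatenations above: conv `c` sees `G0 + c*G` channels, and the 1x1 LFF sees `G0 + C*G` after the last concat. A small sketch (hypothetical helper) of those widths:

```python
def rdb_channels(G0, G, C):
    """Input width of each RDB_Conv and of the 1x1 LFF conv."""
    ins = [G0 + c * G for c in range(C)]  # conv c input channels
    lff_in = G0 + C * G                   # channels after final concat
    return ins, lff_in
```

For RDNconfig 'B' with G0 = 64 this gives conv inputs 64, 128, ..., 512 and an LFF input of 576 channels.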
================================================
FILE: src/model/utils/__init__.py
================================================
================================================
FILE: src/model/utils/tools.py
================================================
import os
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F
def normalize(x):
# maps x from [0, 1] to [-1, 1] in place
return x.mul_(2).add_(-1)
def same_padding(images, ksizes, strides, rates):
assert len(images.size()) == 4
batch_size, channel, rows, cols = images.size()
out_rows = (rows + strides[0] - 1) // strides[0]
out_cols = (cols + strides[1] - 1) // strides[1]
effective_k_row = (ksizes[0] - 1) * rates[0] + 1
effective_k_col = (ksizes[1] - 1) * rates[1] + 1
padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
# Pad the input
padding_top = int(padding_rows / 2.)
padding_left = int(padding_cols / 2.)
padding_bottom = padding_rows - padding_top
padding_right = padding_cols - padding_left
paddings = (padding_left, padding_right, padding_top, padding_bottom)
images = torch.nn.ZeroPad2d(paddings)(images)
return images
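same_padding mirrors TensorFlow-style 'SAME' padding: pad just enough that a strided (possibly dilated) window tiles the input exactly. A pure-Python sketch of the per-dimension arithmetic (hypothetical helper name):

```python
def same_pad_amount(size, ksize, stride, rate=1):
    """One-sided padding amounts for 'SAME' coverage of one dimension,
    matching the computation in same_padding."""
    out = (size + stride - 1) // stride           # ceil(size / stride)
    effective_k = (ksize - 1) * rate + 1          # dilated kernel extent
    total = max(0, (out - 1) * stride + effective_k - size)
    before = total // 2                           # extra pixel goes after
    return before, total - before
```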
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
"""
Extract patches from images and put them in the C output dimension.
:param padding:
:param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
:param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
each dimension of images
:param strides: [stride_rows, stride_cols]
:param rates: [dilation_rows, dilation_cols]
:return: A Tensor
"""
assert len(images.size()) == 4
assert padding in ['same', 'valid']
batch_size, channel, height, width = images.size()
if padding == 'same':
images = same_padding(images, ksizes, strides, rates)
elif padding == 'valid':
pass
else:
raise NotImplementedError('Unsupported padding type: {}.\
Only "same" or "valid" are supported.'.format(padding))
unfold = torch.nn.Unfold(kernel_size=ksizes,
dilation=rates,
padding=0,
stride=strides)
patches = unfold(images)
return patches # [N, C*k*k, L], L is the total number of such blocks
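Under 'same' padding, the L returned by the unfold equals `out_rows * out_cols`, one column per sliding-window position. A sketch of the output shape (hypothetical helper, pure arithmetic):

```python
def unfold_output_shape(channels, rows, cols, ksize, stride):
    """[C*k*k, L] produced by extract_image_patches with 'same'
    padding: each column holds one flattened k x k patch."""
    out_rows = (rows + stride - 1) // stride
    out_cols = (cols + stride - 1) // stride
    return channels * ksize * ksize, out_rows * out_cols
```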
def reduce_mean(x, axis=None, keepdim=False):
if axis is None:
axis = range(len(x.shape))
for i in sorted(axis, reverse=True):
x = torch.mean(x, dim=i, keepdim=keepdim)
return x
def reduce_std(x, axis=None, keepdim=False):
if axis is None:
axis = range(len(x.shape))
for i in sorted(axis, reverse=True):
x = torch.std(x, dim=i, keepdim=keepdim)
return x
def reduce_sum(x, axis=None, keepdim=False):
if axis is None:
axis = range(len(x.shape))
for i in sorted(axis, reverse=True):
x = torch.sum(x, dim=i, keepdim=keepdim)
return x
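The reduce_* helpers apply torch's single-axis reductions repeatedly; processing axes in descending order keeps the remaining indices valid once a dimension is squeezed away (keepdim=False). A shape-level sketch of why the ordering matters:

```python
def reduce_dims(shape, axis):
    """Shape after squeezing `axis` (keepdim=False) in descending
    order, as the reduce_* helpers iterate."""
    shape = list(shape)
    for i in sorted(axis, reverse=True):
        del shape[i]  # deleting high indices first keeps low ones valid
    return shape
```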
================================================
FILE: src/model/vdsr.py
================================================
from model import common
import torch.nn as nn
import torch.nn.init as init
url = {
'r20f64': ''
}
def make_model(args, parent=False):
return VDSR(args)
class VDSR(nn.Module):
def __init__(self, args, conv=common.default_conv):
super(VDSR, self).__init__()
n_resblocks = args.n_resblocks
n_feats = args.n_feats
kernel_size = 3
self.url = url['r{}f{}'.format(n_resblocks, n_feats)]
self.sub_mean = common.MeanShift(args.rgb_range)
self.add_mean = common.MeanShift(args.rgb_range, sign=1)
def basic_block(in_channels, out_channels, act):
return common.BasicBlock(
conv, in_channels, out_channels, kernel_size,
bias=True, bn=False, act=act
)
# define body module
m_body = []
m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True)))
for _ in range(n_resblocks - 2):
m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True)))
m_body.append(basic_block(n_feats, args.n_colors, None))
self.body = nn.Sequential(*m_body)
def forward(self, x):
x = self.sub_mean(x)
res = self.body(x)
res += x
x = self.add_mean(res)
return x
================================================
FILE: src/option.py
================================================
import argparse
import template
parser = argparse.ArgumentParser(description='EDSR and MDSR')
parser.add_argument('--debug', action='store_true',
help='Enables debug mode')
parser.add_argument('--template', default='.',
help='You can set various templates in option.py')
# Hardware specifications
parser.add_argument('--n_threads', type=int, default=18,
help='number of threads for data loading')
parser.add_argument('--cpu', action='store_true',
help='use cpu only')
parser.add_argument('--n_GPUs', type=int, default=1,
help='number of GPUs')
parser.add_argument('--seed', type=int, default=1,
help='random seed')
parser.add_argument('--local_rank',type=int, default=0)
# Data specifications
parser.add_argument('--dir_data', type=str, default='../../../',
help='dataset directory')
parser.add_argument('--dir_demo', type=str, default='../Demo',
help='demo image directory')
parser.add_argument('--data_train', type=str, default='DIV2K',
help='train dataset name')
parser.add_argument('--data_test', type=str, default='DIV2K',
help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-800/801-810',
help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=192,
help='output patch size')
parser.add_argument('--rgb_range', type=int, default=255,
help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
help='number of color channels to use')
parser.add_argument('--chunk_size',type=int,default=144,
help='attention bucket size')
parser.add_argument('--n_hashes',type=int,default=4,
help='number of hash rounds')
parser.add_argument('--chop', action='store_true',
help='enable memory-efficient forward')
parser.add_argument('--no_augment', action='store_true',
help='do not use data augmentation')
# Model specifications
parser.add_argument('--model', default='EDSR',
help='model name')
parser.add_argument('--act', type=str, default='relu',
help='activation function')
parser.add_argument('--pre_train', type=str, default='.',
help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.',
help='pre-trained model directory')
parser.add_argument('--n_resblocks', type=int, default=20,
help='number of residual blocks')
parser.add_argument('--n_feats', type=int, default=64,
help='number of feature maps')
parser.add_argument('--res_scale', type=float, default=1,
help='residual scaling')
parser.add_argument('--shift_mean', default=True,
help='subtract pixel mean from the input')
parser.add_argument('--dilation', action='store_true',
help='use dilated convolution')
parser.add_argument('--precision', type=str, default='single',
choices=('single', 'half'),
help='FP precision for test (single | half)')
# Option for Residual dense network (RDN)
parser.add_argument('--G0', type=int, default=64,
help='default number of filters. (Use in RDN)')
parser.add_argument('--RDNkSize', type=int, default=3,
help='default kernel size. (Use in RDN)')
parser.add_argument('--RDNconfig', type=str, default='B',
help='parameters config of RDN. (Use in RDN)')
parser.add_argument('--depth', type=int, default=12,
help='number of recurrent projection steps (used in MSSR)')
# Option for Residual channel attention network (RCAN)
parser.add_argument('--n_resgroups', type=int, default=10,
help='number of residual groups')
parser.add_argument('--reduction', type=int, default=16,
help='number of feature maps reduction')
# Training specifications
parser.add_argument('--reset', action='store_true',
help='reset the training')
parser.add_argument('--test_every', type=int, default=1000,
help='do test per every N batches')
parser.add_argument('--epochs', type=int, default=1000,
help='number of epochs to train')
parser.add_argument('--batch_size', type=int, default=16,
help='input batch size for training')
parser.add_argument('--split_batch', type=int, default=1,
help='split the batch into smaller chunks')
parser.add_argument('--self_ensemble', action='store_true',
help='use self-ensemble method for test')
parser.add_argument('--test_only', action='store_true',
help='set this option to test the model')
parser.add_argument('--gan_k', type=int, default=1,
help='k value for adversarial loss')
# Optimization specifications
parser.add_argument('--lr', type=float, default=1e-4,
help='learning rate')
parser.add_argument('--decay', type=str, default='200',
help='learning rate decay milestones joined by "-" (e.g. 200-400-600)')
parser.add_argument('--gamma', type=float, default=0.5,
help='learning rate decay factor for step decay')
parser.add_argument('--optimizer', default='ADAM',
choices=('SGD', 'ADAM', 'RMSprop'),
help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9,
help='SGD momentum')
parser.add_argument('--betas', type=tuple, default=(0.9, 0.999),
help='ADAM betas (only the default is usable; type=tuple cannot parse a CLI string)')
parser.add_argument('--epsilon', type=float, default=1e-8,
help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
help='weight decay')
parser.add_argument('--gclip', type=float, default=0,
help='gradient clipping threshold (0 = no clipping)')
# Loss specifications
parser.add_argument('--loss', type=str, default='1*L1',
help='loss function configuration')
parser.add_argument('--skip_threshold', type=float, default=1e8,
help='skip batches whose error exceeds this threshold')
# Log specifications
parser.add_argument('--save', type=str, default='test',
help='file name to save')
parser.add_argument('--load', type=str, default='',
help='file name to load')
parser.add_argument('--resume', type=int, default=0,
help='resume from specific checkpoint')
parser.add_argument('--save_models', action='store_true',
help='save all intermediate models')
parser.add_argument('--print_every', type=int, default=100,
help='how many batches to wait before logging training status')
parser.add_argument('--save_results', action='store_true',
help='save output results')
parser.add_argument('--save_gt', action='store_true',
help='save low-resolution and high-resolution images together')
args = parser.parse_args()
template.set_template(args)
args.scale = list(map(lambda x: int(x), args.scale.split('+')))
args.data_train = args.data_train.split('+')
args.data_test = args.data_test.split('+')
if args.epochs == 0:
args.epochs = 1e8
for arg in vars(args):
if vars(args)[arg] == 'True':
vars(args)[arg] = True
elif vars(args)[arg] == 'False':
vars(args)[arg] = False
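`--scale` accepts several scales joined by '+'; the post-processing above splits the string into a list of ints (one entry per training scale). A one-line sketch of that conversion:

```python
def parse_scale(scale_arg):
    """'2+3+4' -> [2, 3, 4], as done after argument parsing."""
    return [int(s) for s in scale_arg.split('+')]
```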
================================================
FILE: src/template.py
================================================
def set_template(args):
# Set the templates here
if args.template.find('jpeg') >= 0:
args.data_train = 'DIV2K_jpeg'
args.data_test = 'DIV2K_jpeg'
args.epochs = 200
args.decay = '100'
if args.template.find('EDSR_paper') >= 0:
args.model = 'EDSR'
args.n_resblocks = 32
args.n_feats = 256
args.res_scale = 0.1
if args.template.find('MDSR') >= 0:
args.model = 'MDSR'
args.patch_size = 48
args.epochs = 650
if args.template.find('DDBPN') >= 0:
args.model = 'DDBPN'
args.patch_size = 128
args.scale = '4'
args.data_test = 'Set5'
args.batch_size = 20
args.epochs = 1000
args.decay = '500'
args.gamma = 0.1
args.weight_decay = 1e-4
args.loss = '1*MSE'
if args.template.find('GAN') >= 0:
args.epochs = 200
args.lr = 5e-5
args.decay = '150'
if args.template.find('RCAN') >= 0:
args.model = 'RCAN'
args.n_resgroups = 10
args.n_resblocks = 20
args.n_feats = 64
args.chop = True
if args.template.find('VDSR') >= 0:
args.model = 'VDSR'
args.n_resblocks = 20
args.n_feats = 64
args.patch_size = 41
args.lr = 1e-1
================================================
FILE: src/trainer.py
================================================
import os
import math
from decimal import Decimal
import utility
import torch
import torch.nn.utils as utils
from tqdm import tqdm
class Trainer():
def __init__(self, args, loader, my_model, my_loss, ckp):
self.args = args
self.scale = args.scale
self.ckp = ckp
self.loader_train = loader.loader_train
self.loader_test = loader.loader_test
self.model = my_model
self.loss = my_loss
self.optimizer = utility.make_optimizer(args, self.model)
if self.args.load != '':
self.optimizer.load(ckp.dir, epoch=len(ckp.log))
self.error_last = 1e8
def train(self):
self.loss.step()
epoch = self.optimizer.get_last_epoch() + 1
lr = self.optimizer.get_lr()
self.ckp.write_log(
'[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
)
self.loss.start_log()
self.model.train()
timer_data, timer_model = utility.timer(), utility.timer()
# TEMP
self.loader_train.dataset.set_scale(0)
for batch, (lr, hr, _,) in enumerate(self.loader_train):
lr, hr = self.prepare(lr, hr)
timer_data.hold()
timer_model.tic()
self.optimizer.zero_grad()
sr = self.model(lr, 0)
loss = self.loss(sr, hr)
loss.backward()
if self.args.gclip > 0:
utils.clip_grad_value_(
self.model.parameters(),
self.args.gclip
)
self.optimizer.step()
timer_model.hold()
if (batch + 1) % self.args.print_every == 0:
self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
(batch + 1) * self.args.batch_size,
len(self.loader_train.dataset),
self.loss.display_loss(batch),
timer_model.release(),
timer_data.release()))
timer_data.tic()
self.loss.end_log(len(self.loader_train))
self.error_last = self.loss.log[-1, -1]
self.optimizer.schedule()
def test(self):
torch.set_grad_enabled(False)
epoch = self.optimizer.get_last_epoch()
self.ckp.write_log('\nEvaluation:')
self.ckp.add_log(
torch.zeros(1, len(self.loader_test), len(self.scale))
)
self.model.eval()
timer_test = utility.timer()
if self.args.save_results: self.ckp.begin_background()
for idx_data, d in enumerate(self.loader_test):
for idx_scale, scale in enumerate(self.scale):
d.dataset.set_scale(idx_scale)
for lr, hr, filename in tqdm(d, ncols=80):
lr, hr = self.prepare(lr, hr)
sr = self.model(lr, idx_scale)
sr = utility.quantize(sr, self.args.rgb_range)
save_list = [sr]
self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
sr, hr, scale, self.args.rgb_range, dataset=d
)
if self.args.save_gt:
save_list.extend([lr, hr])
if self.args.save_results:
self.ckp.save_results(d, filename[0], save_list, scale)
self.ckp.log[-1, idx_data, idx_scale] /= len(d)
best = self.ckp.log.max(0)
self.ckp.write_log(
'[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
d.dataset.name,
scale,
self.ckp.log[-1, idx_data, idx_scale],
best[0][idx_data, idx_scale],
best[1][idx_data, idx_scale] + 1
)
)
self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
self.ckp.write_log('Saving...')
if self.args.save_results:
self.ckp.end_background()
if not self.args.test_only:
self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))
self.ckp.write_log(
'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
)
torch.set_grad_enabled(True)
def prepare(self, *args):
device = torch.device('cpu' if self.args.cpu else 'cuda')
def _prepare(tensor):
if self.args.precision == 'half': tensor = tensor.half()
return tensor.to(device)
return [_prepare(a) for a in args]
def terminate(self):
if self.args.test_only:
self.test()
return True
else:
epoch = self.optimizer.get_last_epoch() + 1
return epoch >= self.args.epochs
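When `--gclip` is positive, the trainer clamps every gradient element to [-gclip, gclip] via `clip_grad_value_`. A pure-Python sketch of that element-wise clamp (hypothetical helper, not the torch API itself):

```python
def clip_values(grads, clip):
    """Element-wise value clipping, the behaviour applied by
    torch.nn.utils.clip_grad_value_(params, clip)."""
    return [max(-clip, min(clip, g)) for g in grads]
```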
================================================
FILE: src/utility.py
================================================
import os
import math
import time
import datetime
from multiprocessing import Process
from multiprocessing import Queue
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import imageio
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
class timer():
def __init__(self):
self.acc = 0
self.tic()
def tic(self):
self.t0 = time.time()
def toc(self, restart=False):
diff = time.time() - self.t0
if restart: self.t0 = time.time()
return diff
def hold(self):
self.acc += self.toc()
def release(self):
ret = self.acc
self.acc = 0
return ret
def reset(self):
self.acc = 0
class checkpoint():
def __init__(self, args):
self.args = args
self.ok = True
self.log = torch.Tensor()
now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
if not args.load:
if not args.save:
args.save = now
self.dir = os.path.join('..', 'experiment', args.save)
else:
self.dir = os.path.join('..', 'experiment', args.load)
if os.path.exists(self.dir):
self.log = torch.load(self.get_path('psnr_log.pt'))
print('Continue from epoch {}...'.format(len(self.log)))
else:
args.load = ''
if args.reset:
os.system('rm -rf ' + self.dir)
args.load = ''
os.makedirs(self.dir, exist_ok=True)
os.makedirs(self.get_path('model'), exist_ok=True)
for d in args.data_test:
os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)
open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'
self.log_file = open(self.get_path('log.txt'), open_type)
with open(self.get_path('config.txt'), open_type) as f:
f.write(now + '\n\n')
for arg in vars(args):
f.write('{}: {}\n'.format(arg, getattr(args, arg)))
f.write('\n')
self.n_processes = 8
def get_path(self, *subdir):
return os.path.join(self.dir, *subdir)
def save(self, trainer, epoch, is_best=False):
trainer.model.save(self.get_path('model'), epoch, is_best=is_best)
trainer.loss.save(self.dir)
trainer.loss.plot_loss(self.dir, epoch)
self.plot_psnr(epoch)
trainer.optimizer.save(self.dir)
torch.save(self.log, self.get_path('psnr_log.pt'))
def add_log(self, log):
self.log = torch.cat([self.log, log])
def write_log(self, log, refresh=False):
print(log)
self.log_file.write(log + '\n')
if refresh:
self.log_file.close()
self.log_file = open(self.get_path('log.txt'), 'a')
def done(self):
self.log_file.close()
def plot_psnr(self, epoch):
axis = np.linspace(1, epoch, epoch)
for idx_data, d in enumerate(self.args.data_test):
label = 'SR on {}'.format(d)
fig = plt.figure()
plt.title(label)
for idx_scale, scale in enumerate(self.args.scale):
plt.plot(
axis,
self.log[:, idx_data, idx_scale].numpy(),
label='Scale {}'.format(scale)
)
plt.legend()
plt.xlabel('Epochs')
plt.ylabel('PSNR')
plt.grid(True)
plt.savefig(self.get_path('test_{}.pdf'.format(d)))
plt.close(fig)
def begin_background(self):
self.queue = Queue()
def bg_target(queue):
while True:
if not queue.empty():
filename, tensor = queue.get()
if filename is None: break
imageio.imwrite(filename, tensor.numpy())
self.process = [
Process(target=bg_target, args=(self.queue,)) \
for _ in range(self.n_processes)
]
for p in self.process: p.start()
def end_background(self):
for _ in range(self.n_processes): self.queue.put((None, None))
while not self.queue.empty(): time.sleep(1)
for p in self.process: p.join()
def save_results(self, dataset, filename, save_list, scale):
if self.args.save_results:
filename = self.get_path(
'results-{}'.format(dataset.dataset.name),
'{}_x{}_'.format(filename, scale)
)
postfix = ('SR', 'LR', 'HR')
for v, p in zip(save_list, postfix):
normalized = v[0].mul(255 / self.args.rgb_range)
tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()
self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))
def quantize(img, rgb_range):
pixel_range = 255 / rgb_range
return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
if hr.nelement() == 1: return 0
diff = (sr - hr) / rgb_range
if dataset and dataset.dataset.benchmark:
shave = scale
if diff.size(1) > 1:
gray_coeffs = [65.738, 129.057, 25.064]
convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
diff = diff.mul(convert).sum(dim=1)
else:
shave = scale + 6
valid = diff[..., shave:-shave, shave:-shave]
mse = valid.pow(2).mean()
return -10 * math.log10(mse)
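As a sanity check on the formula above (the `dataset=None` branch, so shave = scale + 6 and no grayscale conversion), a uniform one-gray-level error at rgb_range = 255 should give 20·log10(255) ≈ 48.13 dB:

```python
import math
import torch

scale, rgb_range = 2, 255
hr = torch.zeros(1, 3, 32, 32)
sr = hr + 1.0  # every pixel off by exactly one gray level

# mirror calc_psnr with dataset=None: shave the border, then mean squared error
diff = (sr - hr) / rgb_range
shave = scale + 6
valid = diff[..., shave:-shave, shave:-shave]
mse = valid.pow(2).mean().item()
psnr = -10 * math.log10(mse)
print(round(psnr, 2))  # ~48.13 dB
```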
def make_optimizer(args, target):
'''
make optimizer and scheduler together
'''
# optimizer
trainable = filter(lambda x: x.requires_grad, target.parameters())
kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}
if args.optimizer == 'SGD':
optimizer_class = optim.SGD
kwargs_optimizer['momentum'] = args.momentum
elif args.optimizer == 'ADAM':
optimizer_class = optim.Adam
kwargs_optimizer['betas'] = args.betas
kwargs_optimizer['eps'] = args.epsilon
elif args.optimizer == 'RMSprop':
optimizer_class = optim.RMSprop
kwargs_optimizer['eps'] = args.epsilon
# scheduler
milestones = list(map(lambda x: int(x), args.decay.split('-')))
kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}
scheduler_class = lrs.MultiStepLR
class CustomOptimizer(optimizer_class):
def __init__(self, *args, **kwargs):
super(CustomOptimizer, self).__init__(*args, **kwargs)
def _register_scheduler(self, scheduler_class, **kwargs):
self.scheduler = scheduler_class(self, **kwargs)
def save(self, save_dir):
torch.save(self.state_dict(), self.get_dir(save_dir))
def load(self, load_dir, epoch=1):
self.load_state_dict(torch.load(self.get_dir(load_dir)))
if epoch > 1:
for _ in range(epoch): self.scheduler.step()
def get_dir(self, dir_path):
return os.path.join(dir_path, 'optimizer.pt')
def schedule(self):
self.scheduler.step()
def get_lr(self):
return self.scheduler.get_lr()[0]
def get_last_epoch(self):
return self.scheduler.last_epoch
optimizer = CustomOptimizer(trainable, **kwargs_optimizer)
optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)
return optimizer
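make_optimizer expects an args namespace of the kind option.py builds. A sketch of the equivalent ADAM + MultiStepLR setup without the CustomOptimizer wrapper (the args values here are illustrative, not the repo's defaults verbatim):

```python
from types import SimpleNamespace

import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lrs

# illustrative args namespace; real runs parse these in option.py
args = SimpleNamespace(lr=1e-4, weight_decay=0, optimizer='ADAM',
                       betas=(0.9, 0.999), epsilon=1e-8,
                       decay='200-400', gamma=0.5)

model = nn.Linear(4, 4)
trainable = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(trainable, lr=args.lr, weight_decay=args.weight_decay,
                       betas=args.betas, eps=args.epsilon)

# '200-400' parses to lr-halving milestones at epochs 200 and 400
milestones = [int(m) for m in args.decay.split('-')]
scheduler = lrs.MultiStepLR(optimizer, milestones=milestones, gamma=args.gamma)
print(milestones, optimizer.param_groups[0]['lr'])
```

The CustomOptimizer subclass only adds save/load and scheduler bookkeeping on top of this pairing.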
================================================
FILE: src/utils/__init__.py
================================================
================================================
FILE: src/utils/tools.py
================================================
import os
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F
def normalize(x):
return x.mul_(2).add_(-1)
def same_padding(images, ksizes, strides, rates):
assert len(images.size()) == 4
batch_size, channel, rows, cols = images.size()
out_rows = (rows + strides[0] - 1) // strides[0]
out_cols = (cols + strides[1] - 1) // strides[1]
effective_k_row = (ksizes[0] - 1) * rates[0] + 1
effective_k_col = (ksizes[1] - 1) * rates[1] + 1
padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
# Pad the input
padding_top = int(padding_rows / 2.)
padding_left = int(padding_cols / 2.)
padding_bottom = padding_rows - padding_top
padding_right = padding_cols - padding_left
paddings = (padding_left, padding_right, padding_top, padding_bottom)
images = torch.nn.ZeroPad2d(paddings)(images)
return images
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
"""
Extract patches from images and put them in the C output dimension.
:param padding:
:param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
:param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
each dimension of images
:param strides: [stride_rows, stride_cols]
:param rates: [dilation_rows, dilation_cols]
:return: A Tensor
"""
assert len(images.size()) == 4
assert padding in ['same', 'valid']
batch_size, channel, height, width = images.size()
if padding == 'same':
images = same_padding(images, ksizes, strides, rates)
elif padding == 'valid':
pass
else:
raise NotImplementedError(
'Unsupported padding type: {}. Only "same" or "valid" are supported.'.format(padding))
unfold = torch.nn.Unfold(kernel_size=ksizes,
dilation=rates,
padding=0,
stride=strides)
patches = unfold(images)
return patches # [N, C*k*k, L], L is the total number of such blocks
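For stride 1 and an odd kernel (the common case in this repo), the 'same' padding computed above reduces to a symmetric (k - 1) // 2 pad on each side, so the output shape can be checked with a plain Unfold:

```python
import torch

images = torch.randn(1, 4, 6, 6)
ksize, stride = 3, 1

# stride-1 / odd-kernel 'same' padding is symmetric (k - 1) // 2 per side
unfold = torch.nn.Unfold(kernel_size=ksize, padding=(ksize - 1) // 2,
                         stride=stride)
patches = unfold(images)

# [N, C*k*k, L]: 4 * 3 * 3 = 36 patch channels, 6 * 6 = 36 spatial locations
print(tuple(patches.shape))
```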
def reduce_mean(x, axis=None, keepdim=False):
if axis is None:
axis = range(len(x.shape))
for i in sorted(axis, reverse=True):
x = torch.mean(x, dim=i, keepdim=keepdim)
return x
def reduce_std(x, axis=None, keepdim=False):
if axis is None:
axis = range(len(x.shape))
for i in sorted(axis, reverse=True):
x = torch.std(x, dim=i, keepdim=keepdim)
return x
def reduce_sum(x, axis=None, keepdim=False):
if axis is None:
axis = range(len(x.shape))
for i in sorted(axis, reverse=True):
x = torch.sum(x, dim=i, keepdim=keepdim)
return x
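These reducers collapse the listed axes highest-first so the remaining dimension indices stay valid as the tensor shrinks. A small demonstration (the `_demo` helper name is mine, mirroring reduce_sum above):

```python
import torch

def reduce_sum_demo(x, axis=None, keepdim=False):
    # same logic as reduce_sum above, with the axis-is-None check explicit
    if axis is None:
        axis = range(len(x.shape))
    # reduce from the highest dim down so earlier indices remain valid
    for i in sorted(axis, reverse=True):
        x = torch.sum(x, dim=i, keepdim=keepdim)
    return x

x = torch.ones(2, 3, 4)
total = reduce_sum_demo(x)  # all 24 elements summed into a scalar
shape = tuple(reduce_sum_demo(x, axis=[1, 2], keepdim=True).shape)
print(total.item(), shape)  # 24.0 (2, 1, 1)
```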
================================================
FILE: src/videotester.py
================================================
import os
import math
import utility
from data import common
import torch
import cv2
from tqdm import tqdm
class VideoTester():
def __init__(self, args, my_model, ckp):
self.args = args
self.scale = args.scale
self.ckp = ckp
self.model = my_model
self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
def test(self):
torch.set_grad_enabled(False)
self.ckp.write_log('\nEvaluation on video:')
self.model.eval()
timer_test = utility.timer()
for idx_scale, scale in enumerate(self.scale):
vidcap = cv2.VideoCapture(self.args.dir_demo)
total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
vidwri = cv2.VideoWriter(
self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)),
cv2.VideoWriter_fourcc(*'XVID'),
vidcap.get(cv2.CAP_PROP_FPS),
(
int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),
int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
)
)
tqdm_test = tqdm(range(total_frames), ncols=80)
for _ in tqdm_test:
success, lr = vidcap.read()
if not success: break
lr, = common.set_channel(lr, n_channels=self.args.n_colors)
lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)
lr, = self.prepare(lr.unsqueeze(0))
sr = self.model(lr, idx_scale)
sr = utility.quantize(sr, self.args.rgb_range).squeeze(0)
normalized = sr * 255 / self.args.rgb_range
ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
vidwri.write(ndarr)
vidcap.release()
vidwri.release()
self.ckp.write_log(
'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
)
torch.set_grad_enabled(True)
def prepare(self, *args):
device = torch.device('cpu' if self.args.cpu else 'cuda')
def _prepare(tensor):
if self.args.precision == 'half': tensor = tensor.half()
return tensor.to(device)
return [_prepare(a) for a in args]
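The writer loop in test() above hands OpenCV an HWC uint8 array; the scaling and layout conversion in isolation:

```python
import torch

rgb_range = 255
sr = torch.rand(3, 8, 8) * rgb_range  # CHW float output, as from the model

# same conversion as the VideoWriter loop: rescale, byte, CHW -> HWC
normalized = sr * 255 / rgb_range
ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
print(ndarr.shape, ndarr.dtype)
```

Note that cv2.VideoCapture yields BGR frames, so no channel reordering is needed before cv2.VideoWriter consumes the result.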
================================================
SYMBOL INDEX (207 symbols across 33 files)
================================================
FILE: src/data/__init__.py
class MyConcatDataset (line 7) | class MyConcatDataset(ConcatDataset):
method __init__ (line 8) | def __init__(self, datasets):
method set_scale (line 12) | def set_scale(self, idx_scale):
class Data (line 16) | class Data:
method __init__ (line 17) | def __init__(self, args):
FILE: src/data/benchmark.py
class Benchmark (line 11) | class Benchmark(srdata.SRData):
method __init__ (line 12) | def __init__(self, args, name='', train=True, benchmark=True):
method _set_filesystem (line 17) | def _set_filesystem(self, dir_data):
FILE: src/data/common.py
function get_patch (line 8) | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=Fa...
function set_channel (line 34) | def set_channel(*args, n_channels=3):
function np2Tensor (line 49) | def np2Tensor(*args, rgb_range=255):
function augment (line 59) | def augment(*args, hflip=True, rot=True):
FILE: src/data/demo.py
class Demo (line 11) | class Demo(data.Dataset):
method __init__ (line 12) | def __init__(self, args, name='Demo', train=False, benchmark=False):
method __getitem__ (line 26) | def __getitem__(self, idx):
method __len__ (line 34) | def __len__(self):
method set_scale (line 37) | def set_scale(self, idx_scale):
FILE: src/data/div2k.py
class DIV2K (line 4) | class DIV2K(srdata.SRData):
method __init__ (line 5) | def __init__(self, args, name='DIV2K', train=True, benchmark=False):
method _scan (line 20) | def _scan(self):
method _set_filesystem (line 27) | def _set_filesystem(self, dir_data):
FILE: src/data/div2kjpeg.py
class DIV2KJPEG (line 5) | class DIV2KJPEG(div2k.DIV2K):
method __init__ (line 6) | def __init__(self, args, name='', train=True, benchmark=False):
method _set_filesystem (line 12) | def _set_filesystem(self, dir_data):
FILE: src/data/sr291.py
class SR291 (line 3) | class SR291(srdata.SRData):
method __init__ (line 4) | def __init__(self, args, name='SR291', train=True, benchmark=False):
FILE: src/data/srdata.py
class SRData (line 13) | class SRData(data.Dataset):
method __init__ (line 14) | def __init__(self, args, name='', train=True, benchmark=False):
method _scan (line 68) | def _scan(self):
method _set_filesystem (line 84) | def _set_filesystem(self, dir_data):
method _check_and_load (line 91) | def _check_and_load(self, ext, img, f, verbose=True):
method __getitem__ (line 98) | def __getitem__(self, idx):
method __len__ (line 106) | def __len__(self):
method _get_index (line 112) | def _get_index(self, idx):
method _load_file (line 118) | def _load_file(self, idx):
method get_patch (line 135) | def get_patch(self, lr, hr):
method set_scale (line 152) | def set_scale(self, idx_scale):
FILE: src/data/video.py
class Video (line 12) | class Video(data.Dataset):
method __init__ (line 13) | def __init__(self, args, name='Video', train=False, benchmark=False):
method __getitem__ (line 27) | def __getitem__(self, idx):
method __len__ (line 39) | def __len__(self):
method set_scale (line 42) | def set_scale(self, idx_scale):
FILE: src/dataloader.py
function _ms_loop (line 22) | def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, s...
class _MSDataLoaderIter (line 68) | class _MSDataLoaderIter(_DataLoaderIter):
method __init__ (line 70) | def __init__(self, loader):
class MSDataLoader (line 148) | class MSDataLoader(DataLoader):
method __init__ (line 150) | def __init__(self, cfg, *args, **kwargs):
method __iter__ (line 156) | def __iter__(self):
FILE: src/loss/__init__.py
class Loss (line 14) | class Loss(nn.modules.loss._Loss):
method __init__ (line 15) | def __init__(self, args, ckp):
method forward (line 67) | def forward(self, sr, hr):
method step (line 84) | def step(self):
method start_log (line 89) | def start_log(self):
method end_log (line 92) | def end_log(self, n_batches):
method display_loss (line 95) | def display_loss(self, batch):
method plot_loss (line 103) | def plot_loss(self, apath, epoch):
method get_loss_module (line 117) | def get_loss_module(self):
method save (line 123) | def save(self, apath):
method load (line 127) | def load(self, apath, cpu=False):
FILE: src/loss/adversarial.py
class Adversarial (line 12) | class Adversarial(nn.Module):
method __init__ (line 13) | def __init__(self, args, gan_type):
method forward (line 35) | def forward(self, fake, real):
method state_dict (line 95) | def state_dict(self, *args, **kwargs):
method bce (line 101) | def bce(self, real, fake):
FILE: src/loss/discriminator.py
class Discriminator (line 5) | class Discriminator(nn.Module):
method __init__ (line 9) | def __init__(self, args):
method forward (line 50) | def forward(self, x):
FILE: src/loss/hash.py
class HASH (line 8) | class HASH(nn.Module):
method __init__ (line 9) | def __init__(self):
method forward (line 12) | def forward(self, sr, qk, orders, hr, m=3):
FILE: src/loss/vgg.py
class VGG (line 8) | class VGG(nn.Module):
method __init__ (line 9) | def __init__(self, conv_index, rgb_range=1):
method forward (line 24) | def forward(self, sr, hr):
FILE: src/main.py
function main (line 13) | def main():
FILE: src/model/__init__.py
class Model (line 8) | class Model(nn.Module):
method __init__ (line 9) | def __init__(self, args, ckp):
method forward (line 38) | def forward(self, x, idx_scale):
method get_model (line 56) | def get_model(self):
method state_dict (line 62) | def state_dict(self, **kwargs):
method save (line 66) | def save(self, apath, epoch, is_best=False):
method load (line 84) | def load(self, apath, pre_train='.', resume=-1, cpu=False):
method forward_chop (line 114) | def forward_chop(self, x, shave=10, min_size=120000):
method forward_x8 (line 158) | def forward_x8(self, x, forward_function):
FILE: src/model/attention.py
class NonLocalSparseAttention (line 6) | class NonLocalSparseAttention(nn.Module):
method __init__ (line 7) | def __init__( self, n_hashes=4, channels=64, k_size=3, reduction=4, ch...
method LSH (line 16) | def LSH(self, hash_buckets, x):
method add_adjacent_buckets (line 39) | def add_adjacent_buckets(self, x):
method forward (line 44) | def forward(self, input):
class NonLocalAttention (line 117) | class NonLocalAttention(nn.Module):
method __init__ (line 118) | def __init__(self, channel=128, reduction=2, ksize=1, scale=3, stride=...
method forward (line 125) | def forward(self, input):
FILE: src/model/common.py
function batched_index_select (line 8) | def batched_index_select(values, indices):
function default_conv (line 12) | def default_conv(in_channels, out_channels, kernel_size,stride=1, bias=T...
class MeanShift (line 17) | class MeanShift(nn.Conv2d):
method __init__ (line 18) | def __init__(
class BasicBlock (line 29) | class BasicBlock(nn.Sequential):
method __init__ (line 30) | def __init__(
class ResBlock (line 42) | class ResBlock(nn.Module):
method __init__ (line 43) | def __init__(
method forward (line 59) | def forward(self, x):
class Upsampler (line 65) | class Upsampler(nn.Sequential):
method __init__ (line 66) | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
FILE: src/model/ddbpn.py
function make_model (line 10) | def make_model(args, parent=False):
function projection_conv (line 13) | def projection_conv(in_channels, out_channels, scale, up=True):
class DenseProjection (line 29) | class DenseProjection(nn.Module):
method __init__ (line 30) | def __init__(self, in_channels, nr, scale, up=True, bottleneck=True):
method forward (line 55) | def forward(self, x):
class DDBPN (line 68) | class DDBPN(nn.Module):
method __init__ (line 69) | def __init__(self, args):
method forward (line 112) | def forward(self, x):
FILE: src/model/edsr.py
function make_model (line 5) | def make_model(args, parent=False):
class EDSR (line 12) | class EDSR(nn.Module):
method __init__ (line 13) | def __init__(self, args, conv=common.default_conv):
method forward (line 57) | def forward(self, x):
method load_state_dict (line 69) | def load_state_dict(self, state_dict, strict=True):
FILE: src/model/mdsr.py
function make_model (line 5) | def make_model(args, parent=False):
class MDSR (line 8) | class MDSR(nn.Module):
method __init__ (line 9) | def __init__(self, args, conv=common.default_conv):
method forward (line 52) | def forward(self, x):
method set_scale (line 66) | def set_scale(self, scale_idx):
FILE: src/model/mssr.py
function make_model (line 5) | def make_model(args, parent=False):
class MultisourceProjection (line 8) | class MultisourceProjection(nn.Module):
method __init__ (line 9) | def __init__(self, in_channel,kernel_size = 3, conv=common.default_conv):
method forward (line 16) | def forward(self,x):
class RecurrentProjection (line 25) | class RecurrentProjection(nn.Module):
method __init__ (line 26) | def __init__(self, in_channel,kernel_size = 3, conv=common.default_conv):
method forward (line 39) | def forward(self, x):
class MSSR (line 58) | class MSSR(nn.Module):
method __init__ (line 59) | def __init__(self, args, conv=common.default_conv):
method forward (line 94) | def forward(self,input):
FILE: src/model/nlsn.py
function make_model (line 5) | def make_model(args, parent=False):
class NLSN (line 13) | class NLSN(nn.Module):
method __init__ (line 14) | def __init__(self, args, conv=common.default_conv):
method forward (line 56) | def forward(self, x):
method load_state_dict (line 68) | def load_state_dict(self, state_dict, strict=True):
FILE: src/model/rcan.py
function make_model (line 7) | def make_model(args, parent=False):
class CALayer (line 11) | class CALayer(nn.Module):
method __init__ (line 12) | def __init__(self, channel, reduction=16):
method forward (line 27) | def forward(self, x):
class RCAB (line 33) | class RCAB(nn.Module):
method __init__ (line 34) | def __init__(
method forward (line 48) | def forward(self, x):
class ResidualGroup (line 55) | class ResidualGroup(nn.Module):
method __init__ (line 56) | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scal...
method forward (line 66) | def forward(self, x):
class RCAN (line 72) | class RCAN(nn.Module):
method __init__ (line 73) | def __init__(self, args, conv=common.default_conv):
method forward (line 111) | def forward(self, x):
method load_state_dict (line 122) | def load_state_dict(self, state_dict, strict=False):
FILE: src/model/rdn.py
function make_model (line 10) | def make_model(args, parent=False):
class RDB_Conv (line 13) | class RDB_Conv(nn.Module):
method __init__ (line 14) | def __init__(self, inChannels, growRate, kSize=3):
method forward (line 23) | def forward(self, x):
class RDB (line 27) | class RDB(nn.Module):
method __init__ (line 28) | def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
method forward (line 42) | def forward(self, x):
class RDN (line 45) | class RDN(nn.Module):
method __init__ (line 46) | def __init__(self, args):
method forward (line 93) | def forward(self, x):
FILE: src/model/utils/tools.py
function normalize (line 8) | def normalize(x):
function same_padding (line 11) | def same_padding(images, ksizes, strides, rates):
function extract_image_patches (line 30) | def extract_image_patches(images, ksizes, strides, rates, padding='same'):
function reduce_mean (line 59) | def reduce_mean(x, axis=None, keepdim=False):
function reduce_std (line 67) | def reduce_std(x, axis=None, keepdim=False):
function reduce_sum (line 75) | def reduce_sum(x, axis=None, keepdim=False):
FILE: src/model/vdsr.py
function make_model (line 10) | def make_model(args, parent=False):
class VDSR (line 13) | class VDSR(nn.Module):
method __init__ (line 14) | def __init__(self, args, conv=common.default_conv):
method forward (line 39) | def forward(self, x):
FILE: src/template.py
function set_template (line 1) | def set_template(args):
FILE: src/trainer.py
class Trainer (line 11) | class Trainer():
method __init__ (line 12) | def __init__(self, args, loader, my_model, my_loss, ckp):
method train (line 28) | def train(self):
method test (line 74) | def test(self):
method prepare (line 131) | def prepare(self, *args):
method terminate (line 139) | def terminate(self):
FILE: src/utility.py
class timer (line 19) | class timer():
method __init__ (line 20) | def __init__(self):
method tic (line 24) | def tic(self):
method toc (line 27) | def toc(self, restart=False):
method hold (line 32) | def hold(self):
method release (line 35) | def release(self):
method reset (line 41) | def reset(self):
class checkpoint (line 44) | class checkpoint():
method __init__ (line 45) | def __init__(self, args):
method get_path (line 82) | def get_path(self, *subdir):
method save (line 85) | def save(self, trainer, epoch, is_best=False):
method add_log (line 94) | def add_log(self, log):
method write_log (line 97) | def write_log(self, log, refresh=False):
method done (line 104) | def done(self):
method plot_psnr (line 107) | def plot_psnr(self, epoch):
method begin_background (line 126) | def begin_background(self):
method end_background (line 143) | def end_background(self):
method save_results (line 148) | def save_results(self, dataset, filename, save_list, scale):
function quantize (line 161) | def quantize(img, rgb_range):
function calc_psnr (line 165) | def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
function make_optimizer (line 183) | def make_optimizer(args, target):
FILE: src/utils/tools.py
function normalize (line 8) | def normalize(x):
function same_padding (line 11) | def same_padding(images, ksizes, strides, rates):
function extract_image_patches (line 30) | def extract_image_patches(images, ksizes, strides, rates, padding='same'):
function reduce_mean (line 59) | def reduce_mean(x, axis=None, keepdim=False):
function reduce_std (line 67) | def reduce_std(x, axis=None, keepdim=False):
function reduce_sum (line 75) | def reduce_sum(x, axis=None, keepdim=False):
FILE: src/videotester.py
class VideoTester (line 12) | class VideoTester():
method __init__ (line 13) | def __init__(self, args, my_model, ckp):
method test (line 22) | def test(self):
method prepare (line 65) | def prepare(self, *args):