[
  {
    "path": "README.md",
    "content": "# Image Super-Resolution with Non-Local Sparse Attention \nThis repository is for NLSN introduced in the following paper \"Image Super-Resolution with Non-Local Sparse Attention\", CVPR2021, [[Link]](https://openaccess.thecvf.com/content/CVPR2021/papers/Mei_Image_Super-Resolution_With_Non-Local_Sparse_Attention_CVPR_2021_paper.pdf) \n\n\nThe code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) and test on Ubuntu 18.04 environment (Python3.6, PyTorch >= 1.1.0) with V100 GPUs. \n## Contents\n1. [Introduction](#introduction)\n2. [Train](#train)\n3. [Test](#test)\n5. [Citation](#citation)\n6. [Acknowledgements](#acknowledgements)\n\n## Introduction\n\nBoth Non-Local (NL) operation and sparse representa-tion are crucial for Single Image Super-Resolution (SISR).In this paper, we investigate their combinations and proposea novel Non-Local Sparse Attention (NLSA) with dynamicsparse attention pattern. NLSA is designed to retain long-range modeling capability from NL operation while enjoying robustness and high-efficiency of sparse representation.Specifically, NLSA rectifies non-local attention with spherical locality sensitive hashing (LSH) that partitions the input space into hash buckets of related features. For everyquery signal, NLSA assigns a bucket to it and only computes attention within the bucket. The resulting sparse attention prevents the model from attending to locations thatare noisy and less-informative, while reducing the computa-tional cost from quadratic to asymptotic linear with respectto the spatial size. Extensive experiments validate the effectiveness and efficiency of NLSA. With a few non-local sparseattention modules, our architecture, called non-local sparsenetwork (NLSN), reaches state-of-the-art performance forSISR quantitatively and qualitatively.\n\n![Non-Local Sparse Attention](/Figs/Attention.png)\n\nNon-Local Sparse Attention.\n\n![NLSN](/Figs/NLSN.png)\n\nNon-Local Sparse Network.\n\n## Train\n### Prepare training data \n\n1. Download DIV2K training data (800 training + 100 validtion images) from [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar).\n\n2. Specify '--dir_data' based on the HR and LR images path. \n\nFor more informaiton, please refer to [EDSR(PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch).\n\n### Begin to train\n\n1. (optional) Download pretrained models for our paper.\n\n    Pre-trained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zz2a1ih3euzuH3HvWDN-uSki3USym9Cq?usp=sharing) \n\n2. Cd to 'src', run the following script to train models.\n\n    **Example command is in the file 'demo.sh'.**\n\n    ```bash\n    # Example X2 SR\n    python main.py --dir_data ../../ --n_GPUs 4 --rgb_range 1 --chunk_size 144 --n_hashes 4 --save_models --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model NLSN --scale 2 --patch_size 96 --save NLSN_x2 --data_train DIV2K\n\n    ```\n\n## Test\n### Quick start\n1. Download benchmark datasets from [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/benchmark.tar)\n\n1. (optional) Download pretrained models for our paper.\n\n    All the models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zz2a1ih3euzuH3HvWDN-uSki3USym9Cq?usp=sharing) \n\n2. 
\n## Train\n### Prepare training data \n\n1. Download DIV2K training data (800 training + 100 validation images) from the [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/) or [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar).\n\n2. Specify '--dir_data' based on the HR and LR image paths. \n\nFor more information, please refer to [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch).\n\n### Begin to train\n\n1. (optional) Download the pretrained models for our paper.\n\n    Pre-trained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zz2a1ih3euzuH3HvWDN-uSki3USym9Cq?usp=sharing).\n\n2. Cd to 'src' and run the following script to train models.\n\n    **An example command is in the file 'demo.sh'.**\n\n    ```bash\n    # Example X2 SR\n    python main.py --dir_data ../../ --n_GPUs 4 --rgb_range 1 --chunk_size 144 --n_hashes 4 --save_models --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model NLSN --scale 2 --patch_size 96 --save NLSN_x2 --data_train DIV2K\n    ```\n\n## Test\n### Quick start\n1. Download the benchmark datasets from [SNU_CVLab](https://cv.snu.ac.kr/research/EDSR/benchmark.tar).\n\n2. (optional) Download the pretrained models for our paper.\n\n    All the models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1zz2a1ih3euzuH3HvWDN-uSki3USym9Cq?usp=sharing).\n\n3. Cd to 'src' and run the following script.\n\n    **An example command is in the file 'demo.sh'.**\n\n    ```bash\n    # No self-ensemble: NLSN\n    # Example X2 SR\n    python main.py --dir_data ../../ --model NLSN  --chunk_size 144 --data_test Set5+Set14+B100+Urban100 --n_hashes 4 --chop --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1  --pre_train model_x2.pt --test_only\n    ```\n\n## Citation\nIf you find the code helpful in your research or work, please cite the following papers.\n```\n@InProceedings{Mei_2021_CVPR,\n    author    = {Mei, Yiqun and Fan, Yuchen and Zhou, Yuqian},\n    title     = {Image Super-Resolution With Non-Local Sparse Attention},\n    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},\n    month     = {June},\n    year      = {2021},\n    pages     = {3517-3526}\n}\n@InProceedings{Lim_2017_CVPR_Workshops,\n  author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},\n  title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},\n  booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},\n  month = {July},\n  year = {2017}\n}\n```\n## Acknowledgements\nThis code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch) and [reformer-pytorch](https://github.com/lucidrains/reformer-pytorch). We thank the authors for sharing their code.\n"
  },
  {
    "path": "src/__init__.py",
    "content": ""
  },
  {
    "path": "src/data/__init__.py",
    "content": "from importlib import import_module\n#from dataloader import MSDataLoader\nfrom torch.utils.data import dataloader\nfrom torch.utils.data import ConcatDataset\n\n# This is a simple wrapper function for ConcatDataset\nclass MyConcatDataset(ConcatDataset):\n    def __init__(self, datasets):\n        super(MyConcatDataset, self).__init__(datasets)\n        self.train = datasets[0].train\n\n    def set_scale(self, idx_scale):\n        for d in self.datasets:\n            if hasattr(d, 'set_scale'): d.set_scale(idx_scale)\n\nclass Data:\n    def __init__(self, args):\n        self.loader_train = None\n        if not args.test_only:\n            datasets = []\n            for d in args.data_train:\n                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'\n                m = import_module('data.' + module_name.lower())\n                datasets.append(getattr(m, module_name)(args, name=d))\n\n            self.loader_train = dataloader.DataLoader(\n                MyConcatDataset(datasets),\n                batch_size=args.batch_size,\n                shuffle=True,\n                pin_memory=not args.cpu,\n                num_workers=args.n_threads,\n            )\n\n        self.loader_test = []\n        for d in args.data_test:\n            if d in ['Set5', 'Set14', 'B100', 'Urban100']:\n                m = import_module('data.benchmark')\n                testset = getattr(m, 'Benchmark')(args, train=False, name=d)\n            else:\n                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'\n                m = import_module('data.' + module_name.lower())\n                testset = getattr(m, module_name)(args, train=False, name=d)\n\n            self.loader_test.append(\n                dataloader.DataLoader(\n                    testset,\n                    batch_size=1,\n                    shuffle=False,\n                    pin_memory=not args.cpu,\n                    num_workers=args.n_threads,\n                )\n            )\n"
  },
  {
    "path": "src/data/benchmark.py",
    "content": "import os\n\nfrom data import common\nfrom data import srdata\n\nimport numpy as np\n\nimport torch\nimport torch.utils.data as data\n\nclass Benchmark(srdata.SRData):\n    def __init__(self, args, name='', train=True, benchmark=True):\n        super(Benchmark, self).__init__(\n            args, name=name, train=train, benchmark=True\n        )\n\n    def _set_filesystem(self, dir_data):\n        self.apath = os.path.join(dir_data, 'benchmark', self.name)\n        self.dir_hr = os.path.join(self.apath, 'HR')\n        if self.input_large:\n            self.dir_lr = os.path.join(self.apath, 'LR_bicubicL')\n        else:\n            self.dir_lr = os.path.join(self.apath, 'LR_bicubic')\n        self.ext = ('', '.png')\n\n"
  },
  {
    "path": "src/data/common.py",
    "content": "import random\n\nimport numpy as np\nimport skimage.color as sc\n\nimport torch\n\ndef get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):\n    ih, iw = args[0].shape[:2]\n\n    if not input_large:\n        p = scale if multi else 1\n        tp = p * patch_size\n        ip = tp // scale\n    else:\n        tp = patch_size\n        ip = patch_size\n\n    ix = random.randrange(0, iw - ip + 1)\n    iy = random.randrange(0, ih - ip + 1)\n\n    if not input_large:\n        tx, ty = scale * ix, scale * iy\n    else:\n        tx, ty = ix, iy\n\n    ret = [\n        args[0][iy:iy + ip, ix:ix + ip, :],\n        *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]\n    ]\n\n    return ret\n\ndef set_channel(*args, n_channels=3):\n    def _set_channel(img):\n        if img.ndim == 2:\n            img = np.expand_dims(img, axis=2)\n\n        c = img.shape[2]\n        if n_channels == 1 and c == 3:\n            img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)\n        elif n_channels == 3 and c == 1:\n            img = np.concatenate([img] * n_channels, 2)\n\n        return img\n\n    return [_set_channel(a) for a in args]\n\ndef np2Tensor(*args, rgb_range=255):\n    def _np2Tensor(img):\n        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))\n        tensor = torch.from_numpy(np_transpose).float()\n        tensor.mul_(rgb_range / 255)\n\n        return tensor\n\n    return [_np2Tensor(a) for a in args]\n\ndef augment(*args, hflip=True, rot=True):\n    hflip = hflip and random.random() < 0.5\n    vflip = rot and random.random() < 0.5\n    rot90 = rot and random.random() < 0.5\n\n    def _augment(img):\n        if hflip: img = img[:, ::-1, :]\n        if vflip: img = img[::-1, :, :]\n        if rot90: img = img.transpose(1, 0, 2)\n        \n        return img\n\n    return [_augment(a) for a in args]\n\n"
  },
  {
    "path": "src/data/demo.py",
    "content": "import os\n\nfrom data import common\n\nimport numpy as np\nimport imageio\n\nimport torch\nimport torch.utils.data as data\n\nclass Demo(data.Dataset):\n    def __init__(self, args, name='Demo', train=False, benchmark=False):\n        self.args = args\n        self.name = name\n        self.scale = args.scale\n        self.idx_scale = 0\n        self.train = False\n        self.benchmark = benchmark\n\n        self.filelist = []\n        for f in os.listdir(args.dir_demo):\n            if f.find('.png') >= 0 or f.find('.jp') >= 0:\n                self.filelist.append(os.path.join(args.dir_demo, f))\n        self.filelist.sort()\n\n    def __getitem__(self, idx):\n        filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]\n        lr = imageio.imread(self.filelist[idx])\n        lr, = common.set_channel(lr, n_channels=self.args.n_colors)\n        lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)\n\n        return lr_t, -1, filename\n\n    def __len__(self):\n        return len(self.filelist)\n\n    def set_scale(self, idx_scale):\n        self.idx_scale = idx_scale\n\n"
  },
  {
    "path": "src/data/div2k.py",
    "content": "import os\nfrom data import srdata\n\nclass DIV2K(srdata.SRData):\n    def __init__(self, args, name='DIV2K', train=True, benchmark=False):\n        data_range = [r.split('-') for r in args.data_range.split('/')]\n        if train:\n            data_range = data_range[0]\n        else:\n            if args.test_only and len(data_range) == 1:\n                data_range = data_range[0]\n            else:\n                data_range = data_range[1]\n\n        self.begin, self.end = list(map(lambda x: int(x), data_range))\n        super(DIV2K, self).__init__(\n            args, name=name, train=train, benchmark=benchmark\n        )\n\n    def _scan(self):\n        names_hr, names_lr = super(DIV2K, self)._scan()\n        names_hr = names_hr[self.begin - 1:self.end]\n        names_lr = [n[self.begin - 1:self.end] for n in names_lr]\n\n        return names_hr, names_lr\n\n    def _set_filesystem(self, dir_data):\n        super(DIV2K, self)._set_filesystem(dir_data)\n        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')\n        self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic')\n        if self.input_large: self.dir_lr += 'L'\n\n"
  },
  {
    "path": "src/data/div2kjpeg.py",
    "content": "import os\nfrom data import srdata\nfrom data import div2k\n\nclass DIV2KJPEG(div2k.DIV2K):\n    def __init__(self, args, name='', train=True, benchmark=False):\n        self.q_factor = int(name.replace('DIV2K-Q', ''))\n        super(DIV2KJPEG, self).__init__(\n            args, name=name, train=train, benchmark=benchmark\n        )\n\n    def _set_filesystem(self, dir_data):\n        self.apath = os.path.join(dir_data, 'DIV2K')\n        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')\n        self.dir_lr = os.path.join(\n            self.apath, 'DIV2K_Q{}'.format(self.q_factor)\n        )\n        if self.input_large: self.dir_lr += 'L'\n        self.ext = ('.png', '.jpg')\n\n"
  },
  {
    "path": "src/data/sr291.py",
    "content": "from data import srdata\n\nclass SR291(srdata.SRData):\n    def __init__(self, args, name='SR291', train=True, benchmark=False):\n        super(SR291, self).__init__(args, name=name)\n\n"
  },
  {
    "path": "src/data/srdata.py",
    "content": "import os\nimport glob\nimport random\nimport pickle\n\nfrom data import common\n\nimport numpy as np\nimport imageio\nimport torch\nimport torch.utils.data as data\n\nclass SRData(data.Dataset):\n    def __init__(self, args, name='', train=True, benchmark=False):\n        self.args = args\n        self.name = name\n        self.train = train\n        self.split = 'train' if train else 'test'\n        self.do_eval = True\n        self.benchmark = benchmark\n        self.input_large = (args.model == 'VDSR')\n        self.scale = args.scale\n        self.idx_scale = 0\n        \n        self._set_filesystem(args.dir_data)\n        if args.ext.find('img') < 0:\n            path_bin = os.path.join(self.apath, 'bin')\n            os.makedirs(path_bin, exist_ok=True)\n\n        list_hr, list_lr = self._scan()\n        if args.ext.find('img') >= 0 or benchmark:\n            self.images_hr, self.images_lr = list_hr, list_lr\n        elif args.ext.find('sep') >= 0:\n            os.makedirs(\n                self.dir_hr.replace(self.apath, path_bin),\n                exist_ok=True\n            )\n            for s in self.scale:\n                os.makedirs(\n                    os.path.join(\n                        self.dir_lr.replace(self.apath, path_bin),\n                        'X{}'.format(s)\n                    ),\n                    exist_ok=True\n                )\n            \n            self.images_hr, self.images_lr = [], [[] for _ in self.scale]\n            for h in list_hr:\n                b = h.replace(self.apath, path_bin)\n                b = b.replace(self.ext[0], '.pt')\n                self.images_hr.append(b)\n                self._check_and_load(args.ext, h, b, verbose=True) \n            for i, ll in enumerate(list_lr):\n                for l in ll:\n                    b = l.replace(self.apath, path_bin)\n                    b = b.replace(self.ext[1], '.pt')\n                    self.images_lr[i].append(b)\n                    self._check_and_load(args.ext, l, b, verbose=True) \n        if train:\n            n_patches = args.batch_size * args.test_every\n            n_images = len(args.data_train) * len(self.images_hr)\n            if n_images == 0:\n                self.repeat = 0\n            else:\n                self.repeat = max(n_patches // n_images, 1)\n\n    # Below functions as used to prepare images\n    def _scan(self):\n        names_hr = sorted(\n            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))\n        )\n        names_lr = [[] for _ in self.scale]\n        for f in names_hr:\n            filename, _ = os.path.splitext(os.path.basename(f))\n            for si, s in enumerate(self.scale):\n                names_lr[si].append(os.path.join(\n                    self.dir_lr, 'X{}/{}x{}{}'.format(\n                        s, filename, s, self.ext[1]\n                    )\n                ))\n\n        return names_hr, names_lr\n\n    def _set_filesystem(self, dir_data):\n        self.apath = os.path.join(dir_data, self.name)\n        self.dir_hr = os.path.join(self.apath, 'HR')\n        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')\n        if self.input_large: self.dir_lr += 'L'\n        self.ext = ('.png', '.png')\n\n    def _check_and_load(self, ext, img, f, verbose=True):\n        if not os.path.isfile(f) or ext.find('reset') >= 0:\n            if verbose:\n                print('Making a binary: {}'.format(f))\n            with open(f, 'wb') as _f:\n                pickle.dump(imageio.imread(img), _f)\n\n    
def __getitem__(self, idx):\n        lr, hr, filename = self._load_file(idx)\n        pair = self.get_patch(lr, hr)\n        pair = common.set_channel(*pair, n_channels=self.args.n_colors)\n        pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)\n\n        return pair_t[0], pair_t[1], filename\n\n    def __len__(self):\n        if self.train:\n            return len(self.images_hr) * self.repeat\n        else:\n            return len(self.images_hr)\n\n    def _get_index(self, idx):\n        if self.train:\n            return idx % len(self.images_hr)\n        else:\n            return idx\n\n    def _load_file(self, idx):\n        idx = self._get_index(idx)\n        f_hr = self.images_hr[idx]\n        f_lr = self.images_lr[self.idx_scale][idx]\n\n        filename, _ = os.path.splitext(os.path.basename(f_hr))\n        if self.args.ext == 'img' or self.benchmark:\n            hr = imageio.imread(f_hr)\n            lr = imageio.imread(f_lr)\n        elif self.args.ext.find('sep') >= 0:\n            with open(f_hr, 'rb') as _f:\n                hr = pickle.load(_f)\n            with open(f_lr, 'rb') as _f:\n                lr = pickle.load(_f)\n\n        return lr, hr, filename\n\n    def get_patch(self, lr, hr):\n        scale = self.scale[self.idx_scale]\n        if self.train:\n            lr, hr = common.get_patch(\n                lr, hr,\n                patch_size=self.args.patch_size,\n                scale=scale,\n                multi=(len(self.scale) > 1),\n                input_large=self.input_large\n            )\n            if not self.args.no_augment: lr, hr = common.augment(lr, hr)\n        else:\n            ih, iw = lr.shape[:2]\n            hr = hr[0:ih * scale, 0:iw * scale]\n\n        return lr, hr\n\n    def set_scale(self, idx_scale):\n        if not self.input_large:\n            self.idx_scale = idx_scale\n        else:\n            self.idx_scale = random.randint(0, len(self.scale) - 1)\n\n"
  },
  {
    "path": "src/data/video.py",
    "content": "import os\n\nfrom data import common\n\nimport cv2\nimport numpy as np\nimport imageio\n\nimport torch\nimport torch.utils.data as data\n\nclass Video(data.Dataset):\n    def __init__(self, args, name='Video', train=False, benchmark=False):\n        self.args = args\n        self.name = name\n        self.scale = args.scale\n        self.idx_scale = 0\n        self.train = False\n        self.do_eval = False\n        self.benchmark = benchmark\n\n        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))\n        self.vidcap = cv2.VideoCapture(args.dir_demo)\n        self.n_frames = 0\n        self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))\n\n    def __getitem__(self, idx):\n        success, lr = self.vidcap.read()\n        if success:\n            self.n_frames += 1\n            lr, = common.set_channel(lr, n_channels=self.args.n_colors)\n            lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)\n\n            return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)\n        else:\n            vidcap.release()\n            return None\n\n    def __len__(self):\n        return self.total_frames\n\n    def set_scale(self, idx_scale):\n        self.idx_scale = idx_scale\n\n"
  },
  {
    "path": "src/dataloader.py",
    "content": "import threading\nimport random\n\nimport torch\nimport torch.multiprocessing as multiprocessing\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data import SequentialSampler\nfrom torch.utils.data import RandomSampler\nfrom torch.utils.data import BatchSampler\nfrom torch.utils.data import _utils\nfrom torch.utils.data.dataloader import _DataLoaderIter\n\nfrom torch.utils.data._utils import collate\nfrom torch.utils.data._utils import signal_handling\nfrom torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL\nfrom torch.utils.data._utils import ExceptionWrapper\nfrom torch.utils.data._utils import IS_WINDOWS\nfrom torch.utils.data._utils.worker import ManagerWatchdog\n\nfrom torch._six import queue\n\ndef _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):\n    try:\n        collate._use_shared_memory = True\n        signal_handling._set_worker_signal_handlers()\n\n        torch.set_num_threads(1)\n        random.seed(seed)\n        torch.manual_seed(seed)\n\n        data_queue.cancel_join_thread()\n\n        if init_fn is not None:\n            init_fn(worker_id)\n\n        watchdog = ManagerWatchdog()\n\n        while watchdog.is_alive():\n            try:\n                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)\n            except queue.Empty:\n                continue\n\n            if r is None:\n                assert done_event.is_set()\n                return\n            elif done_event.is_set():\n                continue\n\n            idx, batch_indices = r\n            try:\n                idx_scale = 0\n                if len(scale) > 1 and dataset.train:\n                    idx_scale = random.randrange(0, len(scale))\n                    dataset.set_scale(idx_scale)\n\n                samples = collate_fn([dataset[i] for i in batch_indices])\n                samples.append(idx_scale)\n            except Exception:\n                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))\n            else:\n                data_queue.put((idx, samples))\n                del samples\n\n    except KeyboardInterrupt:\n        pass\n\nclass _MSDataLoaderIter(_DataLoaderIter):\n\n    def __init__(self, loader):\n        self.dataset = loader.dataset\n        self.scale = loader.scale\n        self.collate_fn = loader.collate_fn\n        self.batch_sampler = loader.batch_sampler\n        self.num_workers = loader.num_workers\n        self.pin_memory = loader.pin_memory and torch.cuda.is_available()\n        self.timeout = loader.timeout\n\n        self.sample_iter = iter(self.batch_sampler)\n\n        base_seed = torch.LongTensor(1).random_().item()\n\n        if self.num_workers > 0:\n            self.worker_init_fn = loader.worker_init_fn\n            self.worker_queue_idx = 0\n            self.worker_result_queue = multiprocessing.Queue()\n            self.batches_outstanding = 0\n            self.worker_pids_set = False\n            self.shutdown = False\n            self.send_idx = 0\n            self.rcvd_idx = 0\n            self.reorder_dict = {}\n            self.done_event = multiprocessing.Event()\n\n            base_seed = torch.LongTensor(1).random_()[0]\n\n            self.index_queues = []\n            self.workers = []\n            for i in range(self.num_workers):\n                index_queue = multiprocessing.Queue()\n                index_queue.cancel_join_thread()\n                w = multiprocessing.Process(\n                    target=_ms_loop,\n                    
args=(\n                        self.dataset,\n                        index_queue,\n                        self.worker_result_queue,\n                        self.done_event,\n                        self.collate_fn,\n                        self.scale,\n                        base_seed + i,\n                        self.worker_init_fn,\n                        i\n                    )\n                )\n                w.daemon = True\n                w.start()\n                self.index_queues.append(index_queue)\n                self.workers.append(w)\n\n            if self.pin_memory:\n                self.data_queue = queue.Queue()\n                pin_memory_thread = threading.Thread(\n                    target=_utils.pin_memory._pin_memory_loop,\n                    args=(\n                        self.worker_result_queue,\n                        self.data_queue,\n                        torch.cuda.current_device(),\n                        self.done_event\n                    )\n                )\n                pin_memory_thread.daemon = True\n                pin_memory_thread.start()\n                self.pin_memory_thread = pin_memory_thread\n            else:\n                self.data_queue = self.worker_result_queue\n\n            _utils.signal_handling._set_worker_pids(\n                id(self), tuple(w.pid for w in self.workers)\n            )\n            _utils.signal_handling._set_SIGCHLD_handler()\n            self.worker_pids_set = True\n\n            for _ in range(2 * self.num_workers):\n                self._put_indices()\n\n\nclass MSDataLoader(DataLoader):\n\n    def __init__(self, cfg, *args, **kwargs):\n        super(MSDataLoader, self).__init__(\n            *args, **kwargs, num_workers=cfg.n_threads\n        )\n        self.scale = cfg.scale\n\n    def __iter__(self):\n        return _MSDataLoaderIter(self)\n\n"
  },
  {
    "path": "src/demo.sh",
    "content": "#!/bin/bash\n#Train x2\npython main.py --dir_data ../../ --n_GPUs 4 --rgb_range 1 --chunk_size 144 --n_hashes 4 --save_models --lr 1e-4 --decay 200-400-600-800 --epochs 1000 --chop --save_results --n_resblocks 32 --n_feats 256 --res_scale 0.1 --batch_size 16 --model NLSN --scale 2 --patch_size 96 --save NLSN_x2 --data_train DIV2K\n#Test x2\npython main.py --dir_data ../../ --model NLSN  --chunk_size 144 --data_test Set5+Set14+B100+Urban100 --n_hashes 4 --chop --save_results --rgb_range 1 --data_range 801-900 --scale 2 --n_feats 256 --n_resblocks 32 --res_scale 0.1  --pre_train model_x2.pt --test_only \n"
  },
  {
    "path": "src/loss/__init__.py",
    "content": "import os\r\nfrom importlib import import_module\r\n\r\nimport matplotlib\r\nmatplotlib.use('Agg')\r\nimport matplotlib.pyplot as plt\r\n\r\nimport numpy as np\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\n\r\nclass Loss(nn.modules.loss._Loss):\r\n    def __init__(self, args, ckp):\r\n        super(Loss, self).__init__()\r\n        print('Preparing loss function:')\r\n\r\n        self.n_GPUs = args.n_GPUs\r\n        self.loss = []\r\n        self.loss_module = nn.ModuleList()\r\n        for loss in args.loss.split('+'):\r\n            weight, loss_type = loss.split('*')\r\n            if loss_type == 'MSE':\r\n                loss_function = nn.MSELoss()\r\n            elif loss_type == 'L1':\r\n                loss_function = nn.L1Loss()\r\n            elif loss_type.find('VGG') >= 0:\r\n                module = import_module('loss.vgg')\r\n                loss_function = getattr(module, 'VGG')(\r\n                    loss_type[3:],\r\n                    rgb_range=args.rgb_range\r\n                )\r\n            elif loss_type.find('GAN') >= 0:\r\n                module = import_module('loss.adversarial')\r\n                loss_function = getattr(module, 'Adversarial')(\r\n                    args,\r\n                    loss_type\r\n                )\r\n\r\n            self.loss.append({\r\n                'type': loss_type,\r\n                'weight': float(weight),\r\n                'function': loss_function}\r\n            )\r\n            if loss_type.find('GAN') >= 0:\r\n                self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})\r\n\r\n        if len(self.loss) > 1:\r\n            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})\r\n\r\n        for l in self.loss:\r\n            if l['function'] is not None:\r\n                print('{:.3f} * {}'.format(l['weight'], l['type']))\r\n                self.loss_module.append(l['function'])\r\n\r\n        self.log = torch.Tensor()\r\n\r\n        device = torch.device('cpu' if args.cpu else 'cuda')\r\n        self.loss_module.to(device)\r\n        if args.precision == 'half': self.loss_module.half()\r\n        if not args.cpu and args.n_GPUs > 1:\r\n            self.loss_module = nn.DataParallel(self.loss_module,range(args.n_GPUs))\r\n\r\n        if args.load != '': self.load(ckp.dir, cpu=args.cpu)\r\n\r\n    def forward(self, sr, hr):\r\n        losses = []\r\n        for i, l in enumerate(self.loss):\r\n            if l['function'] is not None:\r\n                loss = l['function'](sr, hr)\r\n                effective_loss = l['weight'] * loss\r\n                losses.append(effective_loss)\r\n                self.log[-1, i] += effective_loss.item()\r\n            elif l['type'] == 'DIS':\r\n                self.log[-1, i] += self.loss[i - 1]['function'].loss\r\n\r\n        loss_sum = sum(losses)\r\n        if len(self.loss) > 1:\r\n            self.log[-1, -1] += loss_sum.item()\r\n\r\n        return loss_sum\r\n\r\n    def step(self):\r\n        for l in self.get_loss_module():\r\n            if hasattr(l, 'scheduler'):\r\n                l.scheduler.step()\r\n\r\n    def start_log(self):\r\n        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))\r\n\r\n    def end_log(self, n_batches):\r\n        self.log[-1].div_(n_batches)\r\n\r\n    def display_loss(self, batch):\r\n        n_samples = batch + 1\r\n        log = []\r\n        for l, c in zip(self.loss, self.log[-1]):\r\n            log.append('[{}: 
{:.4f}]'.format(l['type'], c / n_samples))\r\n\r\n        return ''.join(log)\r\n\r\n    def plot_loss(self, apath, epoch):\r\n        axis = np.linspace(1, epoch, epoch)\r\n        for i, l in enumerate(self.loss):\r\n            label = '{} Loss'.format(l['type'])\r\n            fig = plt.figure()\r\n            plt.title(label)\r\n            plt.plot(axis, self.log[:, i].numpy(), label=label)\r\n            plt.legend()\r\n            plt.xlabel('Epochs')\r\n            plt.ylabel('Loss')\r\n            plt.grid(True)\r\n            plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))\r\n            plt.close(fig)\r\n\r\n    def get_loss_module(self):\r\n        if self.n_GPUs == 1:\r\n            return self.loss_module\r\n        else:\r\n            return self.loss_module.module\r\n\r\n    def save(self, apath):\r\n        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))\r\n        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))\r\n\r\n    def load(self, apath, cpu=False):\r\n        if cpu:\r\n            kwargs = {'map_location': lambda storage, loc: storage}\r\n        else:\r\n            kwargs = {}\r\n\r\n        self.load_state_dict(torch.load(\r\n            os.path.join(apath, 'loss.pt'),\r\n            **kwargs\r\n        ))\r\n        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))\r\n        for l in self.get_loss_module():\r\n            if hasattr(l, 'scheduler'):\r\n                for _ in range(len(self.log)): l.scheduler.step()\r\n\r\n"
  },
  {
    "path": "src/loss/__loss__.py",
    "content": ""
  },
  {
    "path": "src/loss/adversarial.py",
    "content": "import utility\nfrom types import SimpleNamespace\n\nfrom model import common\nfrom loss import discriminator\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\n\nclass Adversarial(nn.Module):\n    def __init__(self, args, gan_type):\n        super(Adversarial, self).__init__()\n        self.gan_type = gan_type\n        self.gan_k = args.gan_k\n        self.dis = discriminator.Discriminator(args)\n        if gan_type == 'WGAN_GP':\n            # see https://arxiv.org/pdf/1704.00028.pdf pp.4\n            optim_dict = {\n                'optimizer': 'ADAM',\n                'betas': (0, 0.9),\n                'epsilon': 1e-8,\n                'lr': 1e-5,\n                'weight_decay': args.weight_decay,\n                'decay': args.decay,\n                'gamma': args.gamma\n            }\n            optim_args = SimpleNamespace(**optim_dict)\n        else:\n            optim_args = args\n\n        self.optimizer = utility.make_optimizer(optim_args, self.dis)\n\n    def forward(self, fake, real):\n        # updating discriminator...\n        self.loss = 0\n        fake_detach = fake.detach()     # do not backpropagate through G\n        for _ in range(self.gan_k):\n            self.optimizer.zero_grad()\n            # d: B x 1 tensor\n            d_fake = self.dis(fake_detach)\n            d_real = self.dis(real)\n            retain_graph = False\n            if self.gan_type == 'GAN':\n                loss_d = self.bce(d_real, d_fake)\n            elif self.gan_type.find('WGAN') >= 0:\n                loss_d = (d_fake - d_real).mean()\n                if self.gan_type.find('GP') >= 0:\n                    epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)\n                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)\n                    hat.requires_grad = True\n                    d_hat = self.dis(hat)\n                    gradients = torch.autograd.grad(\n                        outputs=d_hat.sum(), inputs=hat,\n                        retain_graph=True, create_graph=True, only_inputs=True\n                    )[0]\n                    gradients = gradients.view(gradients.size(0), -1)\n                    gradient_norm = gradients.norm(2, dim=1)\n                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()\n                    loss_d += gradient_penalty\n            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks\n            elif self.gan_type == 'RGAN':\n                better_real = d_real - d_fake.mean(dim=0, keepdim=True)\n                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)\n                loss_d = self.bce(better_real, better_fake)\n                retain_graph = True\n\n            # Discriminator update\n            self.loss += loss_d.item()\n            loss_d.backward(retain_graph=retain_graph)\n            self.optimizer.step()\n\n            if self.gan_type == 'WGAN':\n                for p in self.dis.parameters():\n                    p.data.clamp_(-1, 1)\n\n        self.loss /= self.gan_k\n\n        # updating generator...\n        d_fake_bp = self.dis(fake)      # for backpropagation, use fake as it is\n        if self.gan_type == 'GAN':\n            label_real = torch.ones_like(d_fake_bp)\n            loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)\n        elif self.gan_type.find('WGAN') >= 0:\n            loss_g = -d_fake_bp.mean()\n        elif self.gan_type == 'RGAN':\n            better_real = 
d_real - d_fake_bp.mean(dim=0, keepdim=True)\n            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)\n            loss_g = self.bce(better_fake, better_real)\n\n        # Generator loss\n        return loss_g\n    \n    def state_dict(self, *args, **kwargs):\n        state_discriminator = self.dis.state_dict(*args, **kwargs)\n        state_optimizer = self.optimizer.state_dict()\n\n        return dict(**state_discriminator, **state_optimizer)\n\n    def bce(self, real, fake):\n        label_real = torch.ones_like(real)\n        label_fake = torch.zeros_like(fake)\n        bce_real = F.binary_cross_entropy_with_logits(real, label_real)\n        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)\n        bce_loss = bce_real + bce_fake\n        return bce_loss\n               \n# Some references\n# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py\n# OR\n# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py\n"
  },
  {
    "path": "src/loss/demo.sh",
    "content": ""
  },
  {
    "path": "src/loss/discriminator.py",
    "content": "from model import common\n\nimport torch.nn as nn\n\nclass Discriminator(nn.Module):\n    '''\n        output is not normalized\n    '''\n    def __init__(self, args):\n        super(Discriminator, self).__init__()\n\n        in_channels = args.n_colors\n        out_channels = 64\n        depth = 7\n\n        def _block(_in_channels, _out_channels, stride=1):\n            return nn.Sequential(\n                nn.Conv2d(\n                    _in_channels,\n                    _out_channels,\n                    3,\n                    padding=1,\n                    stride=stride,\n                    bias=False\n                ),\n                nn.BatchNorm2d(_out_channels),\n                nn.LeakyReLU(negative_slope=0.2, inplace=True)\n            )\n\n        m_features = [_block(in_channels, out_channels)]\n        for i in range(depth):\n            in_channels = out_channels\n            if i % 2 == 1:\n                stride = 1\n                out_channels *= 2\n            else:\n                stride = 2\n            m_features.append(_block(in_channels, out_channels, stride=stride))\n\n        patch_size = args.patch_size // (2**((depth + 1) // 2))\n        m_classifier = [\n            nn.Linear(out_channels * patch_size**2, 1024),\n            nn.LeakyReLU(negative_slope=0.2, inplace=True),\n            nn.Linear(1024, 1)\n        ]\n\n        self.features = nn.Sequential(*m_features)\n        self.classifier = nn.Sequential(*m_classifier)\n\n    def forward(self, x):\n        features = self.features(x)\n        output = self.classifier(features.view(features.size(0), -1))\n\n        return output\n\n"
  },
  {
    "path": "src/loss/hash.py",
    "content": "from model import common\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.models as models\n\nclass HASH(nn.Module):\n    def __init__(self):\n        super(HASH, self).__init__()\n        self.l1 = nn.L1Loss()\n    def forward(self, sr, qk, orders, hr, m=3):\n        #hash loss\n        qk = F.normalize(qk, p=2, dim=1, eps=5e-5)\n        N,C,H,W = qk.shape\n        qk = qk.view(N,C,H*W)\n        qk_t = qk.permute(0,2,1).contiguous()\n        similarity_map = F.relu(torch.matmul(qk_t, qk),inplace=True) #[N,H*W,H*W]\n        \n        orders = orders.unsqueeze(2).expand_as(similarity_map)#[N,H*W,H*W]\n        orders_t = torch.transpose(orders,1,2)\n        dist = torch.pow(orders-orders_t,2)\n        \n        ls = torch.mean(similarity_map*torch.log(torch.exp(dist+m)+1))\n        ld = torch.mean((1-similarity_map)*torch.log(torch.exp(-dist+m)+1))\n        loss = 0.005*(ls+ld)+self.l1(sr,hr) \n\n        return loss\n"
  },
  {
    "path": "src/loss/vgg.py",
    "content": "from model import common\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.models as models\n\nclass VGG(nn.Module):\n    def __init__(self, conv_index, rgb_range=1):\n        super(VGG, self).__init__()\n        vgg_features = models.vgg19(pretrained=True).features\n        modules = [m for m in vgg_features]\n        if conv_index.find('22') >= 0:\n            self.vgg = nn.Sequential(*modules[:8])\n        elif conv_index.find('54') >= 0:\n            self.vgg = nn.Sequential(*modules[:35])\n\n        vgg_mean = (0.485, 0.456, 0.406)\n        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)\n        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)\n        for p in self.parameters():\n            p.requires_grad = False\n\n    def forward(self, sr, hr):\n        def _forward(x):\n            x = self.sub_mean(x)\n            x = self.vgg(x)\n            return x\n            \n        vgg_sr = _forward(sr)\n        with torch.no_grad():\n            vgg_hr = _forward(hr.detach())\n\n        loss = F.mse_loss(vgg_sr, vgg_hr)\n\n        return loss\n"
  },
  {
    "path": "src/main.py",
    "content": "import torch\n\nimport utility\nimport data\nimport model\nimport loss\nfrom option import args\nfrom trainer import Trainer\n\ntorch.manual_seed(args.seed)\ncheckpoint = utility.checkpoint(args)\n\ndef main():\n    global model\n    if args.data_test == ['video']:\n        from videotester import VideoTester\n        model = model.Model(args, checkpoint)\n        print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))\n        t = VideoTester(args, model, checkpoint)\n        t.test()\n    else:\n        if checkpoint.ok:\n            loader = data.Data(args)\n            _model = model.Model(args, checkpoint)\n            print('Total params: %.2fM' % (sum(p.numel() for p in _model.parameters())/1000000.0))\n            _loss = loss.Loss(args, checkpoint) if not args.test_only else None\n            t = Trainer(args, loader, _model, _loss, checkpoint)\n            while not t.terminate():\n                t.train()\n                t.test()\n\n            checkpoint.done()\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "src/model/LICENSE",
    "content": "MIT License\n\nCopyright (c) 2018 Sanghyun Son\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "src/model/README.md",
    "content": "# EDSR-PyTorch\n![](/figs/main.png)\n\nThis repository is an official PyTorch implementation of the paper **\"Enhanced Deep Residual Networks for Single Image Super-Resolution\"** from **CVPRW 2017, 2nd NTIRE**.\nYou can find the original code and more information from [here](https://github.com/LimBee/NTIRE2017).\n\nIf you find our work useful in your research or publication, please cite our work:\n\n[1] Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee, **\"Enhanced Deep Residual Networks for Single Image Super-Resolution,\"** <i>2nd NTIRE: New Trends in Image Restoration and Enhancement workshop and challenge on image super-resolution in conjunction with **CVPR 2017**. </i> [[PDF](http://openaccess.thecvf.com/content_cvpr_2017_workshops/w12/papers/Lim_Enhanced_Deep_Residual_CVPR_2017_paper.pdf)] [[arXiv](https://arxiv.org/abs/1707.02921)] [[Slide](https://cv.snu.ac.kr/research/EDSR/Presentation_v3(release).pptx)]\n```\n@InProceedings{Lim_2017_CVPR_Workshops,\n  author = {Lim, Bee and Son, Sanghyun and Kim, Heewon and Nah, Seungjun and Lee, Kyoung Mu},\n  title = {Enhanced Deep Residual Networks for Single Image Super-Resolution},\n  booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},\n  month = {July},\n  year = {2017}\n}\n```\nWe provide scripts for reproducing all the results from our paper. You can train your own model from scratch, or use pre-trained model to enlarge your images.\n\n**Differences between Torch version**\n* Codes are much more compact. (Removed all unnecessary parts.)\n* Models are smaller. (About half.)\n* Slightly better performances.\n* Training and evaluation requires less memory.\n* Python-based.\n\n## Dependencies\n* Python 3.6\n* PyTorch >= 0.4.0\n* numpy\n* skimage\n* **imageio**\n* matplotlib\n* tqdm\n\n**Recent updates**\n\n* July 22, 2018\n  * Thanks for recent commits that contains RDN and RCAN. Please see ``code/demo.sh`` to train/test those models.\n  * Now the dataloader is much stable than the previous version. Please erase ``DIV2K/bin`` folder that is created before this commit. Also, please avoid to use ``--ext bin`` argument. Our code will automatically pre-decode png images before training. If you do not have enough spaces(~10GB) in your disk, we recommend ``--ext img``(But SLOW!).\n\n\n## Code\nClone this repository into any place you want.\n```bash\ngit clone https://github.com/thstkdgus35/EDSR-PyTorch\ncd EDSR-PyTorch\n```\n\n## Quick start (Demo)\nYou can test our super-resolution algorithm with your own images. Place your images in ``test`` folder. (like ``test/<your_image>``) We support **png** and **jpeg** files.\n\nRun the script in ``src`` folder. 
Before you run the demo, please uncomment the appropriate line in ```demo.sh``` that you want to execute.\n```bash\ncd src       # You are now in */EDSR-PyTorch/src\nsh demo.sh\n```\n\nYou can find the result images in the ```experiment/test/results``` folder.\n\n| Model | Scale | File name (.pt) | Parameters | **PSNR** |\n|  ---  |  ---  | ---       | ---        | ---  |\n| **EDSR** | 2 | EDSR_baseline_x2 | 1.37 M | 34.61 dB |\n| | | *EDSR_x2 | 40.7 M | 35.03 dB |\n| | 3 | EDSR_baseline_x3 | 1.55 M | 30.92 dB |\n| | | *EDSR_x3 | 43.7 M | 31.26 dB |\n| | 4 | EDSR_baseline_x4 | 1.52 M | 28.95 dB |\n| | | *EDSR_x4 | 43.1 M | 29.25 dB |\n| **MDSR** | 2 | MDSR_baseline | 3.23 M | 34.63 dB |\n| | | *MDSR | 7.95 M| 34.92 dB |\n| | 3 | MDSR_baseline | | 30.94 dB |\n| | | *MDSR | | 31.22 dB |\n| | 4 | MDSR_baseline | | 28.97 dB |\n| | | *MDSR | | 29.24 dB |\n\n*Baseline models are in ``experiment/model``. Please download our final models from [here](https://cv.snu.ac.kr/research/EDSR/model_pytorch.tar) (542MB).\n**We measured PSNR using DIV2K 0801 ~ 0900, RGB channels, without self-ensemble. (scale + 2) pixels from the image boundary are ignored.\n\nYou can evaluate your models with widely-used benchmark datasets:\n\n[Set5 - Bevilacqua et al. BMVC 2012](http://people.rennes.inria.fr/Aline.Roumy/results/SR_BMVC12.html),\n\n[Set14 - Zeyde et al. LNCS 2010](https://sites.google.com/site/romanzeyde/research-interests),\n\n[B100 - Martin et al. ICCV 2001](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/),\n\n[Urban100 - Huang et al. CVPR 2015](https://sites.google.com/site/jbhuang0604/publications/struct_sr).\n\nFor these datasets, we first convert the result images to YCbCr color space and evaluate PSNR on the Y channel only. You can download the [benchmark datasets](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) (250MB). Set ``--dir_data <where_benchmark_folder_located>`` to evaluate EDSR and MDSR with the benchmarks.\n\n## How to train EDSR and MDSR\nWe used the [DIV2K](http://www.vision.ee.ethz.ch/%7Etimofter/publications/Agustsson-CVPRW-2017.pdf) dataset to train our model. Please download it from [here](https://cv.snu.ac.kr/research/EDSR/DIV2K.tar) (7.1GB).\n\nUnpack the tar file to any place you want. Then, change the ```dir_data``` argument in ```src/option.py``` to the place where the DIV2K images are located.\n\nWe recommend pre-processing the images before training. This step will decode all **png** files and save them as binaries. Use the ``--ext sep_reset`` argument on your first run. You can skip the decoding part and use the saved binaries with the ``--ext sep`` argument.\n\nIf you have enough RAM (>= 32GB), you can use the ``--ext bin`` argument to pack all DIV2K images into one binary file.\n\nYou can train EDSR and MDSR by yourself. All scripts are provided in ``src/demo.sh``. Note that EDSR (x3, x4) requires a pre-trained EDSR (x2). You can ignore this constraint by removing the ```--pre_train <x2 model>``` argument.\n\n```bash\ncd src       # You are now in */EDSR-PyTorch/src\nsh demo.sh\n```\n
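\nFor reference, a minimal sketch of what the pre-decoding step stores (it mirrors ``_check_and_load`` in ``src/data/srdata.py``): each png is decoded once and the raw array is pickled, so later runs skip png decoding.\n\n```python\nimport pickle\nimport imageio\n\ndef make_binary(png_path, bin_path):\n    # decode the png once and cache the raw numpy array; the cache file is\n    # named *.pt but it is a pickle, not a torch checkpoint\n    with open(bin_path, 'wb') as f:\n        pickle.dump(imageio.imread(png_path), f)\n```\n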
\n**Update log**\n* Jan 04, 2018\n  * Many parts are re-written. You cannot use previous scripts and models directly.\n  * Pre-trained MDSR is temporarily disabled.\n  * Training details are included.\n\n* Jan 09, 2018\n  * Missing files are included (```src/data/MyImage.py```).\n  * Some links are fixed.\n\n* Jan 16, 2018\n  * A memory-efficient forward function is implemented.\n  * Add the --chop_forward argument to your script to enable it.\n  * Basically, this function first splits a large image into small patches. The patches are merged after super-resolution. I checked this function with 12GB of memory and a 4000 x 2000 input image at scale 4. (Therefore, the output will be 16000 x 8000.)\n\n* Feb 21, 2018\n  * Fixed a problem when loading a pre-trained multi-GPU model.\n  * Added a pre-trained scale 2 baseline model.\n  * This code now only saves the best-performing model by default. For MDSR, 'the best' can be ambiguous. Use the --save_models argument to save all the intermediate models.\n  * PyTorch 0.3.1 changed its implementation of the DataLoader. Therefore, I also changed my implementation of MSDataLoader. You can find it on the feature/dataloader branch.\n\n* Feb 23, 2018\n  * Now PyTorch 0.3.1 is the default. Use the legacy/0.3.0 branch if you use the old version.\n   \n  * With the new ``src/data/DIV2K.py`` code, one can easily create a new data class for super-resolution.\n  * New binary data pack. (Please remove the ``DIV2K_decoded`` folder from your dataset if you have one.)\n  * With ``--ext bin``, this code will automatically generate and save the binary data pack that corresponds to the previous ``DIV2K_decoded``. (This requires a huge amount of RAM (~45GB; swap can be used), so please be careful.)\n  * If you cannot make the binary pack, just use the default setting (``--ext img``).\n\n  * Fixed a bug where the PSNR in the log and the PSNR calculated from the saved images did not match.\n  * Now saved images have better quality! (PSNR is ~0.1dB higher than the original code.)\n  * Added a performance comparison between the Torch7 model and PyTorch models.\n\n* Mar 5, 2018\n  * All baseline models are uploaded.\n  * Now supports half-precision at test time. Use ``--precision half``  to enable it. This does not degrade the output images.\n\n* Mar 11, 2018\n  * Fixed some typos in the code and script.\n  * Now --ext img is the default setting. Although we recommend --ext bin when training, please use --ext img when you use --test_only.\n  * The skip_batch operation is implemented. Use the --skip_threshold argument to skip batches that you want to ignore. Although this function is not exactly the same as that of the Torch7 version, it will work as you expect.\n\n* Mar 20, 2018\n  * Use ``--ext sep_reset`` to pre-decode large png files. Those decoded files will be saved in the same directory as the DIV2K png files. After the first run, you can use ``--ext sep`` to save time.\n  * Now supports various benchmark datasets. For example, try ``--data_test Set5`` to test your model on the Set5 images.\n  * Changed the behavior of skip_batch.\n\n* Mar 29, 2018\n  * We now provide all models from our paper.\n  * We also provide the ``MDSR_baseline_jpeg`` model, which suppresses JPEG artifacts in the original low-resolution image. Please use it if you have any trouble.\n  * The ``MyImage`` dataset is changed to the ``Demo`` dataset. Also, it works more efficiently than before.\n  * Some code and scripts are re-written.\n\n* Apr 9, 2018\n  * VGG and adversarial losses are implemented based on [SRGAN](http://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf). 
[WGAN](https://arxiv.org/abs/1701.07875) and [gradient penalty](https://arxiv.org/abs/1704.00028) are also implemented, but they are not tested yet.\n  * Much of the code is refactored. If you find a bug, please report it.\n  * [D-DBPN](https://arxiv.org/abs/1803.02735) is implemented. The default setting is D-DBPN-L.\n\n* Apr 26, 2018\n  * Compatible with PyTorch 0.4.0\n  * Please use the legacy/0.3.1 branch if you are using the old version of PyTorch.\n  * Minor bug fixes\n"
  },
  {
    "path": "src/model/__init__.py",
    "content": "import os\nfrom importlib import import_module\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Variable\n\nclass Model(nn.Module):\n    def __init__(self, args, ckp):\n        super(Model, self).__init__()\n        print('Making model...')\n\n        self.scale = args.scale\n        self.idx_scale = 0\n        self.self_ensemble = args.self_ensemble\n        self.chop = args.chop\n        self.precision = args.precision\n        self.cpu = args.cpu\n        self.device = torch.device('cpu' if args.cpu else 'cuda')\n        self.n_GPUs = args.n_GPUs\n        self.save_models = args.save_models\n\n        module = import_module('model.' + args.model.lower())\n        self.model = module.make_model(args).to(self.device)\n        if args.precision == 'half': self.model.half()\n\n        if not args.cpu and args.n_GPUs > 1:\n            self.model = nn.DataParallel(self.model, range(args.n_GPUs))\n\n        self.load(\n            ckp.dir,\n            pre_train=args.pre_train,\n            resume=args.resume,\n            cpu=args.cpu\n        )\n        print(self.model, file=ckp.log_file)\n\n    def forward(self, x, idx_scale):\n        self.idx_scale = idx_scale\n        target = self.get_model()\n        if hasattr(target, 'set_scale'):\n            target.set_scale(idx_scale)\n\n        if self.self_ensemble and not self.training:\n            if self.chop:\n                forward_function = self.forward_chop\n            else:\n                forward_function = self.model.forward\n\n            return self.forward_x8(x, forward_function)\n        elif self.chop and not self.training:\n            return self.forward_chop(x)\n        else:\n            return self.model(x)\n\n    def get_model(self):\n        if self.n_GPUs == 1:\n            return self.model\n        else:\n            return self.model.module\n\n    def state_dict(self, **kwargs):\n        target = self.get_model()\n        return target.state_dict(**kwargs)\n\n    def save(self, apath, epoch, is_best=False):\n        target = self.get_model()\n        torch.save(\n            target.state_dict(), \n            os.path.join(apath, 'model_latest.pt')\n        )\n        if is_best:\n            torch.save(\n                target.state_dict(),\n                os.path.join(apath, 'model_best.pt')\n            )\n        \n        if self.save_models:\n            torch.save(\n                target.state_dict(),\n                os.path.join(apath, 'model_{}.pt'.format(epoch))\n            )\n\n    def load(self, apath, pre_train='.', resume=-1, cpu=False):\n        if cpu:\n            kwargs = {'map_location': lambda storage, loc: storage}\n        else:\n            kwargs = {}\n\n        if resume == -1:\n            self.get_model().load_state_dict(\n                torch.load(\n                    os.path.join(apath, 'model_latest.pt'),\n                    **kwargs\n                ),\n                strict=False\n            )\n        elif resume == 0:\n            if pre_train != '.':\n                print('Loading model from {}'.format(pre_train))\n                self.get_model().load_state_dict(\n                    torch.load(pre_train, **kwargs),\n                    strict=False\n                )\n        else:\n            self.get_model().load_state_dict(\n                torch.load(\n                    os.path.join(apath, 'model', 'model_{}.pt'.format(resume)),\n                    **kwargs\n                ),\n                strict=False\n            )\n\n    
def forward_chop(self, x, shave=10, min_size=120000):\n        scale = self.scale[self.idx_scale]\n        n_GPUs = min(self.n_GPUs, 4)\n        b, c, h, w = x.size()\n        h_half, w_half = h // 2, w // 2\n        h_size, w_size = h_half + shave, w_half + shave\n        h_size +=4-h_size%4\n        w_size +=8-w_size%8\n        \n        lr_list = [\n            x[:, :, 0:h_size, 0:w_size],\n            x[:, :, 0:h_size, (w - w_size):w],\n            x[:, :, (h - h_size):h, 0:w_size],\n            x[:, :, (h - h_size):h, (w - w_size):w]]\n\n        if w_size * h_size < min_size:\n            sr_list = []\n            for i in range(0, 4, n_GPUs):\n                lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)\n                sr_batch = self.model(lr_batch)\n                sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))\n        else:\n            sr_list = [\n                self.forward_chop(patch, shave=shave, min_size=min_size) \\\n                for patch in lr_list\n            ]\n\n        h, w = scale * h, scale * w\n        h_half, w_half = scale * h_half, scale * w_half\n        h_size, w_size = scale * h_size, scale * w_size\n        shave *= scale\n\n        output = x.new(b, c, h, w)\n        output[:, :, 0:h_half, 0:w_half] \\\n            = sr_list[0][:, :, 0:h_half, 0:w_half]\n        output[:, :, 0:h_half, w_half:w] \\\n            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]\n        output[:, :, h_half:h, 0:w_half] \\\n            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]\n        output[:, :, h_half:h, w_half:w] \\\n            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]\n\n        return output\n\n    def forward_x8(self, x, forward_function):\n        def _transform(v, op):\n            if self.precision != 'single': v = v.float()\n\n            v2np = v.data.cpu().numpy()\n            if op == 'v':\n                tfnp = v2np[:, :, :, ::-1].copy()\n            elif op == 'h':\n                tfnp = v2np[:, :, ::-1, :].copy()\n            elif op == 't':\n                tfnp = v2np.transpose((0, 1, 3, 2)).copy()\n\n            ret = torch.Tensor(tfnp).to(self.device)\n            if self.precision == 'half': ret = ret.half()\n\n            return ret\n\n        lr_list = [x]\n        for tf in 'v', 'h', 't':\n            lr_list.extend([_transform(t, tf) for t in lr_list])\n\n        sr_list = [forward_function(aug) for aug in lr_list]\n        for i in range(len(sr_list)):\n            if i > 3:\n                sr_list[i] = _transform(sr_list[i], 't')\n            if i % 4 > 1:\n                sr_list[i] = _transform(sr_list[i], 'h')\n            if (i % 4) % 2 == 1:\n                sr_list[i] = _transform(sr_list[i], 'v')\n\n        output_cat = torch.cat(sr_list, dim=0)\n        output = output_cat.mean(dim=0, keepdim=True)\n\n        return output\n\n"
  },
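  {
    "path": "examples/self_ensemble_sketch.py",
    "content": "# Hypothetical editorial sketch, not part of the original NLSN repository.\n# It illustrates the geometry behind Model.forward_x8 in src/model.py: the\n# input is augmented with two flips and a transpose (8 variants in total),\n# and each output is mapped back before averaging. This is a pure-torch\n# stand-in for the numpy-based _transform; all names here are made up.\nimport torch\n\ndef transform(v, op):\n    # 'v': flip width, 'h': flip height, 't': swap the two spatial axes\n    if op == 'v':\n        return v.flip(-1)\n    if op == 'h':\n        return v.flip(-2)\n    if op == 't':\n        return v.transpose(-2, -1)\n    raise ValueError(op)\n\nx = torch.arange(12.0).view(1, 1, 3, 4)\n\n# Build all 8 variants exactly as forward_x8 does: each round doubles the list.\nvariants = [x]\nfor op in ('v', 'h', 't'):\n    variants.extend([transform(t, op) for t in variants])\nassert len(variants) == 8\n\n# Undo the augmentation with the same index rules used when merging outputs.\nrecovered = []\nfor i, v in enumerate(variants):\n    if i > 3:\n        v = transform(v, 't')\n    if i % 4 > 1:\n        v = transform(v, 'h')\n    if (i % 4) % 2 == 1:\n        v = transform(v, 'v')\n    recovered.append(v)\n\n# Every inverted variant matches the original, so averaging them is valid.\nassert all(torch.equal(v, x) for v in recovered)\nprint('all 8 self-ensemble transforms invert correctly')\n"
  },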
  {
    "path": "src/model/attention.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom model import common\n\nclass NonLocalSparseAttention(nn.Module):\n    def __init__( self, n_hashes=4, channels=64, k_size=3, reduction=4, chunk_size=144, conv=common.default_conv, res_scale=1):\n        super(NonLocalSparseAttention,self).__init__()\n        self.chunk_size = chunk_size\n        self.n_hashes = n_hashes\n        self.reduction = reduction\n        self.res_scale = res_scale\n        self.conv_match = common.BasicBlock(conv, channels, channels//reduction, k_size, bn=False, act=None)\n        self.conv_assembly = common.BasicBlock(conv, channels, channels, 1, bn=False, act=None)\n\n    def LSH(self, hash_buckets, x):\n        #x: [N,H*W,C]\n        N = x.shape[0]\n        device = x.device\n        \n        #generate random rotation matrix\n        rotations_shape = (1, x.shape[-1], self.n_hashes, hash_buckets//2) #[1,C,n_hashes,hash_buckets//2]\n        random_rotations = torch.randn(rotations_shape, dtype=x.dtype, device=device).expand(N, -1, -1, -1) #[N, C, n_hashes, hash_buckets//2]\n        \n        #locality sensitive hashing\n        rotated_vecs = torch.einsum('btf,bfhi->bhti', x, random_rotations) #[N, n_hashes, H*W, hash_buckets//2]\n        rotated_vecs = torch.cat([rotated_vecs, -rotated_vecs], dim=-1) #[N, n_hashes, H*W, hash_buckets]\n        \n        #get hash codes\n        hash_codes = torch.argmax(rotated_vecs, dim=-1) #[N,n_hashes,H*W]\n        \n        #add offsets to avoid hash codes overlapping between hash rounds \n        offsets = torch.arange(self.n_hashes, device=device) \n        offsets = torch.reshape(offsets * hash_buckets, (1, -1, 1))\n        hash_codes = torch.reshape(hash_codes + offsets, (N, -1,)) #[N,n_hashes*H*W]\n     \n        return hash_codes \n    \n    def add_adjacent_buckets(self, x):\n            x_extra_back = torch.cat([x[:,:,-1:, ...], x[:,:,:-1, ...]], dim=2)\n            x_extra_forward = torch.cat([x[:,:,1:, ...], x[:,:,:1,...]], dim=2)\n            return torch.cat([x, x_extra_back,x_extra_forward], dim=3)\n\n    def forward(self, input):\n        \n        N,_,H,W = input.shape\n        x_embed = self.conv_match(input).view(N,-1,H*W).contiguous().permute(0,2,1)\n        y_embed = self.conv_assembly(input).view(N,-1,H*W).contiguous().permute(0,2,1)\n        L,C = x_embed.shape[-2:]\n\n        #number of hash buckets/hash bits\n        hash_buckets = min(L//self.chunk_size + (L//self.chunk_size)%2, 128)\n        \n        #get assigned hash codes/bucket number         \n        hash_codes = self.LSH(hash_buckets, x_embed) #[N,n_hashes*H*W]\n        hash_codes = hash_codes.detach()\n\n        #group elements with same hash code by sorting\n        _, indices = hash_codes.sort(dim=-1) #[N,n_hashes*H*W]\n        _, undo_sort = indices.sort(dim=-1) #undo_sort to recover original order\n        mod_indices = (indices % L) #now range from (0->H*W)\n        x_embed_sorted = common.batched_index_select(x_embed, mod_indices) #[N,n_hashes*H*W,C]\n        y_embed_sorted = common.batched_index_select(y_embed, mod_indices) #[N,n_hashes*H*W,C]\n        \n        #pad the embedding if it cannot be divided by chunk_size\n        padding = self.chunk_size - L%self.chunk_size if L%self.chunk_size!=0 else 0\n        x_att_buckets = torch.reshape(x_embed_sorted, (N, self.n_hashes,-1, C)) #[N, n_hashes, H*W,C]\n        y_att_buckets = torch.reshape(y_embed_sorted, (N, self.n_hashes,-1, C*self.reduction)) \n        if padding:\n            pad_x = 
x_att_buckets[:,:,-padding:,:].clone()\n            pad_y = y_att_buckets[:,:,-padding:,:].clone()\n            x_att_buckets = torch.cat([x_att_buckets,pad_x],dim=2)\n            y_att_buckets = torch.cat([y_att_buckets,pad_y],dim=2)\n\n        x_att_buckets = torch.reshape(x_att_buckets,(N,self.n_hashes,-1,self.chunk_size,C)) #[N, n_hashes, num_chunks, chunk_size, C]\n        y_att_buckets = torch.reshape(y_att_buckets,(N,self.n_hashes,-1,self.chunk_size, C*self.reduction))\n\n        x_match = F.normalize(x_att_buckets, p=2, dim=-1, eps=5e-5)\n\n        # allow attending to adjacent buckets\n        x_match = self.add_adjacent_buckets(x_match)\n        y_att_buckets = self.add_adjacent_buckets(y_att_buckets)\n\n        # unnormalized attention score\n        raw_score = torch.einsum('bhkie,bhkje->bhkij', x_att_buckets, x_match) #[N, n_hashes, num_chunks, chunk_size, chunk_size*3]\n\n        # softmax (via logsumexp for numerical stability)\n        bucket_score = torch.logsumexp(raw_score, dim=-1, keepdim=True)\n        score = torch.exp(raw_score - bucket_score) # (after softmax)\n        bucket_score = torch.reshape(bucket_score,[N,self.n_hashes,-1])\n\n        # attention\n        ret = torch.einsum('bukij,bukje->bukie', score, y_att_buckets) #[N, n_hashes, num_chunks, chunk_size, C*reduction]\n        ret = torch.reshape(ret,(N,self.n_hashes,-1,C*self.reduction))\n\n        # if padded, remove the extra elements\n        if padding:\n            ret = ret[:,:,:-padding,:].clone()\n            bucket_score = bucket_score[:,:,:-padding].clone()\n\n        # recover the original order\n        ret = torch.reshape(ret, (N, -1, C*self.reduction)) #[N, n_hashes*H*W, C*reduction]\n        bucket_score = torch.reshape(bucket_score, (N, -1,)) #[N,n_hashes*H*W]\n        ret = common.batched_index_select(ret, undo_sort) #[N, n_hashes*H*W, C*reduction]\n        bucket_score = bucket_score.gather(1, undo_sort) #[N,n_hashes*H*W]\n\n        # weighted sum over the multi-round attention results\n        ret = torch.reshape(ret, (N, self.n_hashes, L, C*self.reduction)) #[N, n_hashes, H*W, C*reduction]\n        bucket_score = torch.reshape(bucket_score, (N, self.n_hashes, L, 1))\n        probs = nn.functional.softmax(bucket_score, dim=1)\n        ret = torch.sum(ret * probs, dim=1)\n\n        ret = ret.permute(0,2,1).view(N,-1,H,W).contiguous()*self.res_scale + input\n        return ret\n\n\nclass NonLocalAttention(nn.Module):\n    def __init__(self, channel=128, reduction=2, ksize=1, scale=3, stride=1, softmax_scale=10, average=True, res_scale=1, conv=common.default_conv):\n        super(NonLocalAttention, self).__init__()\n        self.res_scale = res_scale\n        self.conv_match1 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())\n        self.conv_match2 = common.BasicBlock(conv, channel, channel//reduction, 1, bn=False, act=nn.PReLU())\n        self.conv_assembly = common.BasicBlock(conv, channel, channel, 1, bn=False, act=nn.PReLU())\n\n    def forward(self, input):\n        x_embed_1 = self.conv_match1(input)\n        x_embed_2 = self.conv_match2(input)\n        x_assembly = self.conv_assembly(input)\n\n        N,C,H,W = x_embed_1.shape\n        x_embed_1 = x_embed_1.permute(0,2,3,1).view((N,H*W,C))\n        x_embed_2 = x_embed_2.view(N,C,H*W)\n        score = torch.matmul(x_embed_1, x_embed_2)\n        score = F.softmax(score, dim=2)\n        x_assembly = x_assembly.view(N,-1,H*W).permute(0,2,1)\n        x_final = torch.matmul(score, x_assembly)\n        return x_final.permute(0,2,1).view(N,-1,H,W) + self.res_scale*input\n"
  },
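  {
    "path": "examples/lsh_bucketing_sketch.py",
    "content": "# Hypothetical editorial sketch, not part of the original NLSN repository.\n# It reproduces, standalone, the spherical LSH step of\n# NonLocalSparseAttention.LSH in src/model/attention.py: random rotations\n# project each pixel feature, argmax over [proj, -proj] picks a bucket, and\n# per-round offsets keep hash rounds disjoint. Sizes below are toy values.\nimport torch\n\nN, L, C = 2, 36, 16            # batch, number of pixels (H*W), channels\nn_hashes, hash_buckets = 4, 8\n\nx = torch.randn(N, L, C)\n\nrotations = torch.randn(1, C, n_hashes, hash_buckets // 2).expand(N, -1, -1, -1)\nrotated = torch.einsum('btf,bfhi->bhti', x, rotations)  # [N, n_hashes, L, buckets/2]\nrotated = torch.cat([rotated, -rotated], dim=-1)        # antipodal buckets\ncodes = torch.argmax(rotated, dim=-1)                   # [N, n_hashes, L]\n\noffsets = torch.arange(n_hashes).reshape(1, -1, 1) * hash_buckets\ncodes = (codes + offsets).reshape(N, -1)                # [N, n_hashes*L]\n\n# Sorting by code groups same-bucket pixels; argsort of the sort indices\n# recovers the original order, exactly as undo_sort does in attention.py.\n_, indices = codes.sort(dim=-1)\n_, undo_sort = indices.sort(dim=-1)\nassert torch.equal(codes.gather(-1, indices).gather(-1, undo_sort), codes)\nprint('bucket codes, round 0, batch 0:', codes.reshape(N, n_hashes, L)[0, 0])\n"
  },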
  {
    "path": "src/model/common.py",
    "content": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef batched_index_select(values, indices):\n    last_dim = values.shape[-1]\n    return values.gather(1, indices[:, :, None].expand(-1, -1, last_dim))\n    \ndef default_conv(in_channels, out_channels, kernel_size,stride=1, bias=True):\n    return nn.Conv2d(\n        in_channels, out_channels, kernel_size,\n        padding=(kernel_size//2),stride=stride, bias=bias)\n\nclass MeanShift(nn.Conv2d):\n    def __init__(\n        self, rgb_range,\n        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):\n\n        super(MeanShift, self).__init__(3, 3, kernel_size=1)\n        std = torch.Tensor(rgb_std)\n        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)\n        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std\n        for p in self.parameters():\n            p.requires_grad = False\n\nclass BasicBlock(nn.Sequential):\n    def __init__(\n        self, conv, in_channels, out_channels, kernel_size, stride=1, bias=True,\n        bn=False, act=nn.PReLU()):\n\n        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]\n        if bn:\n            m.append(nn.BatchNorm2d(out_channels))\n        if act is not None:\n            m.append(act)\n\n        super(BasicBlock, self).__init__(*m)\n\nclass ResBlock(nn.Module):\n    def __init__(\n        self, conv, n_feats, kernel_size,\n        bias=True, bn=False, act=nn.PReLU(), res_scale=1):\n\n        super(ResBlock, self).__init__()\n        m = []\n        for i in range(2):\n            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))\n            if bn:\n                m.append(nn.BatchNorm2d(n_feats))\n            if i == 0:\n                m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x).mul(self.res_scale)\n        res += x\n\n        return res\n\nclass Upsampler(nn.Sequential):\n    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):\n\n        m = []\n        if (scale & (scale - 1)) == 0:    # Is scale = 2^n?\n            for _ in range(int(math.log(scale, 2))):\n                m.append(conv(n_feats, 4 * n_feats, 3, bias=bias))\n                m.append(nn.PixelShuffle(2))\n                if bn:\n                    m.append(nn.BatchNorm2d(n_feats))\n                if act == 'relu':\n                    m.append(nn.ReLU(True))\n                elif act == 'prelu':\n                    m.append(nn.PReLU(n_feats))\n\n        elif scale == 3:\n            m.append(conv(n_feats, 9 * n_feats, 3, bias=bias))\n            m.append(nn.PixelShuffle(3))\n            if bn:\n                m.append(nn.BatchNorm2d(n_feats))\n            if act == 'relu':\n                m.append(nn.ReLU(True))\n            elif act == 'prelu':\n                m.append(nn.PReLU(n_feats))\n        else:\n            raise NotImplementedError\n\n        super(Upsampler, self).__init__(*m)\n\n"
  },
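  {
    "path": "examples/pixelshuffle_upsampler_sketch.py",
    "content": "# Hypothetical editorial sketch, not part of the original NLSN repository.\n# It shows why common.Upsampler stacks conv(n_feats -> 4*n_feats) followed by\n# PixelShuffle(2): each block folds 4 channel groups into a 2x larger grid,\n# so log2(scale) blocks reach any power-of-two scale, channels preserved.\nimport math\n\nimport torch\nimport torch.nn as nn\n\nn_feats, scale = 8, 4\nblocks = []\nfor _ in range(int(math.log2(scale))):\n    blocks += [nn.Conv2d(n_feats, 4 * n_feats, 3, padding=1), nn.PixelShuffle(2)]\nup = nn.Sequential(*blocks)\n\nx = torch.randn(1, n_feats, 12, 12)\ny = up(x)\n# channels preserved, spatial size multiplied by the scale factor\nassert y.shape == (1, n_feats, 12 * scale, 12 * scale)\nprint(tuple(y.shape))\n"
  },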
  {
    "path": "src/model/ddbpn.py",
    "content": "# Deep Back-Projection Networks For Super-Resolution\n# https://arxiv.org/abs/1803.02735\n\nfrom model import common\n\nimport torch\nimport torch.nn as nn\n\n\ndef make_model(args, parent=False):\n    return DDBPN(args)\n\ndef projection_conv(in_channels, out_channels, scale, up=True):\n    kernel_size, stride, padding = {\n        2: (6, 2, 2),\n        4: (8, 4, 2),\n        8: (12, 8, 2)\n    }[scale]\n    if up:\n        conv_f = nn.ConvTranspose2d\n    else:\n        conv_f = nn.Conv2d\n\n    return conv_f(\n        in_channels, out_channels, kernel_size,\n        stride=stride, padding=padding\n    )\n\nclass DenseProjection(nn.Module):\n    def __init__(self, in_channels, nr, scale, up=True, bottleneck=True):\n        super(DenseProjection, self).__init__()\n        if bottleneck:\n            self.bottleneck = nn.Sequential(*[\n                nn.Conv2d(in_channels, nr, 1),\n                nn.PReLU(nr)\n            ])\n            inter_channels = nr\n        else:\n            self.bottleneck = None\n            inter_channels = in_channels\n\n        self.conv_1 = nn.Sequential(*[\n            projection_conv(inter_channels, nr, scale, up),\n            nn.PReLU(nr)\n        ])\n        self.conv_2 = nn.Sequential(*[\n            projection_conv(nr, inter_channels, scale, not up),\n            nn.PReLU(inter_channels)\n        ])\n        self.conv_3 = nn.Sequential(*[\n            projection_conv(inter_channels, nr, scale, up),\n            nn.PReLU(nr)\n        ])\n\n    def forward(self, x):\n        if self.bottleneck is not None:\n            x = self.bottleneck(x)\n\n        a_0 = self.conv_1(x)\n        b_0 = self.conv_2(a_0)\n        e = b_0.sub(x)\n        a_1 = self.conv_3(e)\n\n        out = a_0.add(a_1)\n\n        return out\n\nclass DDBPN(nn.Module):\n    def __init__(self, args):\n        super(DDBPN, self).__init__()\n        scale = args.scale[0]\n\n        n0 = 128\n        nr = 32\n        self.depth = 6\n\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        initial = [\n            nn.Conv2d(args.n_colors, n0, 3, padding=1),\n            nn.PReLU(n0),\n            nn.Conv2d(n0, nr, 1),\n            nn.PReLU(nr)\n        ]\n        self.initial = nn.Sequential(*initial)\n\n        self.upmodules = nn.ModuleList()\n        self.downmodules = nn.ModuleList()\n        channels = nr\n        for i in range(self.depth):\n            self.upmodules.append(\n                DenseProjection(channels, nr, scale, True, i > 1)\n            )\n            if i != 0:\n                channels += nr\n        \n        channels = nr\n        for i in range(self.depth - 1):\n            self.downmodules.append(\n                DenseProjection(channels, nr, scale, False, i != 0)\n            )\n            channels += nr\n\n        reconstruction = [\n            nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1) \n        ]\n        self.reconstruction = nn.Sequential(*reconstruction)\n\n        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        x = self.initial(x)\n\n        h_list = []\n        l_list = []\n        for i in range(self.depth - 1):\n            if i == 0:\n                l = x\n            else:\n                l = torch.cat(l_list, dim=1)\n            h_list.append(self.upmodules[i](l))\n            
l_list.append(self.downmodules[i](torch.cat(h_list, dim=1)))\n        \n        h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1)))\n        out = self.reconstruction(torch.cat(h_list, dim=1))\n        out = self.add_mean(out)\n\n        return out\n\n"
  },
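  {
    "path": "examples/ddbpn_projection_sketch.py",
    "content": "# Hypothetical editorial sketch, not part of the original NLSN repository.\n# It checks the (kernel, stride, padding) table used by ddbpn.projection_conv:\n# with out = (in - 1)*s - 2p + k for ConvTranspose2d and\n# out = (in + 2p - k)//s + 1 for Conv2d, each entry scales by exactly s, so\n# an up projection followed by a down projection returns the original size.\nimport torch\nimport torch.nn as nn\n\nfor scale, (k, s, p) in {2: (6, 2, 2), 4: (8, 4, 2), 8: (12, 8, 2)}.items():\n    x = torch.randn(1, 3, 16, 16)\n    up = nn.ConvTranspose2d(3, 3, k, stride=s, padding=p)\n    down = nn.Conv2d(3, 3, k, stride=s, padding=p)\n    y = up(x)\n    assert y.shape[-1] == 16 * scale\n    assert down(y).shape[-1] == 16\n    print('x{}:'.format(scale), tuple(y.shape[-2:]))\n"
  },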
  {
    "path": "src/model/edsr.py",
    "content": "from model import common\nfrom model import attention\nimport torch.nn as nn\n\ndef make_model(args, parent=False):\n    if args.dilation:\n        from model import dilated\n        return EDSR(args, dilated.dilated_conv)\n    else:\n        return EDSR(args)\n\nclass EDSR(nn.Module):\n    def __init__(self, args, conv=common.default_conv):\n        super(EDSR, self).__init__()\n\n        n_resblock = args.n_resblocks\n        n_feats = args.n_feats\n        kernel_size = 3 \n        scale = args.scale[0]\n        act = nn.ReLU(True)\n\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        #self.msa = attention.PyramidAttention(channel=256, reduction=8,res_scale=args.res_scale);         \n        # define head module\n        m_head = [conv(args.n_colors, n_feats, kernel_size)]\n\n        # define body module\n        m_body = [\n            common.ResBlock(\n                conv, n_feats, kernel_size, act=act, res_scale=args.res_scale\n            ) for _ in range(n_resblock//2)\n        ]\n        #m_body.append(self.msa)\n        for _ in range(n_resblock//2):\n            m_body.append( common.ResBlock(\n                conv, n_feats, kernel_size, act=act, res_scale=args.res_scale\n            ))\n        m_body.append(conv(n_feats, n_feats, kernel_size))\n\n        # define tail module\n        m_tail = [\n            common.Upsampler(conv, scale, n_feats, act=False),\n            nn.Conv2d(\n                n_feats, args.n_colors, kernel_size,\n                padding=(kernel_size//2)\n            )\n        ]\n\n        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n\n        self.head = nn.Sequential(*m_head)\n        self.body = nn.Sequential(*m_body)\n        self.tail = nn.Sequential(*m_tail)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        x = self.head(x)\n\n        res = self.body(x)\n        res += x\n\n        x = self.tail(res)\n        x = self.add_mean(x)\n\n        return x \n\n    def load_state_dict(self, state_dict, strict=True):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') == -1:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n"
  },
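  {
    "path": "examples/edsr_forward_sketch.py",
    "content": "# Hypothetical editorial sketch, not part of the original NLSN repository.\n# It lists the minimal args fields EDSR actually reads and runs one forward\n# pass. It assumes it is executed from the src/ directory so the model\n# package is importable; the sizes are small stand-ins, not paper settings.\nfrom types import SimpleNamespace\n\nimport torch\n\nfrom model import edsr\n\nargs = SimpleNamespace(\n    n_resblocks=4, n_feats=16, scale=[2], rgb_range=1,\n    n_colors=3, res_scale=0.1, dilation=False,\n)\nnet = edsr.make_model(args)\n\nx = torch.randn(1, 3, 24, 24)\nwith torch.no_grad():\n    y = net(x)\nassert y.shape == (1, 3, 48, 48)  # the tail upsamples by scale[0]\nprint(tuple(y.shape))\n"
  },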
  {
    "path": "src/model/mdsr.py",
    "content": "from model import common\n\nimport torch.nn as nn\n\ndef make_model(args, parent=False):\n    return MDSR(args)\n\nclass MDSR(nn.Module):\n    def __init__(self, args, conv=common.default_conv):\n        super(MDSR, self).__init__()\n        n_resblocks = args.n_resblocks\n        n_feats = args.n_feats\n        kernel_size = 3\n        self.scale_idx = 0\n\n        act = nn.ReLU(True)\n\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n\n        m_head = [conv(args.n_colors, n_feats, kernel_size)]\n\n        self.pre_process = nn.ModuleList([\n            nn.Sequential(\n                common.ResBlock(conv, n_feats, 5, act=act),\n                common.ResBlock(conv, n_feats, 5, act=act)\n            ) for _ in args.scale\n        ])\n\n        m_body = [\n            common.ResBlock(\n                conv, n_feats, kernel_size, act=act\n            ) for _ in range(n_resblocks)\n        ]\n        m_body.append(conv(n_feats, n_feats, kernel_size))\n\n        self.upsample = nn.ModuleList([\n            common.Upsampler(\n                conv, s, n_feats, act=False\n            ) for s in args.scale\n        ])\n\n        m_tail = [conv(n_feats, args.n_colors, kernel_size)]\n\n        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n\n        self.head = nn.Sequential(*m_head)\n        self.body = nn.Sequential(*m_body)\n        self.tail = nn.Sequential(*m_tail)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        x = self.head(x)\n        x = self.pre_process[self.scale_idx](x)\n\n        res = self.body(x)\n        res += x\n\n        x = self.upsample[self.scale_idx](res)\n        x = self.tail(x)\n        x = self.add_mean(x)\n\n        return x\n\n    def set_scale(self, scale_idx):\n        self.scale_idx = scale_idx\n\n"
  },
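  {
    "path": "examples/mdsr_scale_switch_sketch.py",
    "content": "# Hypothetical editorial sketch, not part of the original NLSN repository.\n# It distills the scale-switching pattern in mdsr.py: one shared body is\n# flanked by per-scale ModuleList branches, and set_scale only changes which\n# branch index forward() uses. TinyMultiScale is an invented toy module.\nimport torch\nimport torch.nn as nn\n\nclass TinyMultiScale(nn.Module):\n    def __init__(self, scales=(2, 3, 4), n_feats=8):\n        super().__init__()\n        self.scale_idx = 0\n        self.body = nn.Conv2d(n_feats, n_feats, 3, padding=1)  # shared weights\n        self.upsample = nn.ModuleList(\n            nn.Upsample(scale_factor=s) for s in scales        # per-scale heads\n        )\n\n    def set_scale(self, scale_idx):\n        self.scale_idx = scale_idx\n\n    def forward(self, x):\n        return self.upsample[self.scale_idx](self.body(x))\n\nnet = TinyMultiScale()\nx = torch.randn(1, 8, 10, 10)\nfor idx, s in enumerate((2, 3, 4)):\n    net.set_scale(idx)\n    assert net(x).shape[-1] == 10 * s\nprint('one shared body serves all three scales')\n"
  },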
  {
    "path": "src/model/mssr.py",
    "content": "from model import common\nimport torch.nn as nn\nimport torch\nfrom model.attention import ContextualAttention,NonLocalAttention\ndef make_model(args, parent=False):\n    return MSSR(args)\n\nclass MultisourceProjection(nn.Module):\n    def __init__(self, in_channel,kernel_size = 3, conv=common.default_conv):\n        super(MultisourceProjection, self).__init__()\n        self.up_attention = ContextualAttention(scale=2)\n        self.down_attention = NonLocalAttention()\n        self.upsample = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])\n        self.encoder = common.ResBlock(conv, in_channel, kernel_size, act=nn.PReLU(), res_scale=1)\n    \n    def forward(self,x):\n        down_map = self.upsample(self.down_attention(x))\n        up_map = self.up_attention(x)\n\n        err = self.encoder(up_map-down_map)\n        final_map = down_map + err\n        \n        return final_map\n\nclass RecurrentProjection(nn.Module):\n    def __init__(self, in_channel,kernel_size = 3, conv=common.default_conv):\n        super(RecurrentProjection, self).__init__()\n        self.multi_source_projection_1 = MultisourceProjection(in_channel,kernel_size=kernel_size,conv=conv)\n        self.multi_source_projection_2 = MultisourceProjection(in_channel,kernel_size=kernel_size,conv=conv)\n        self.down_sample_1 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])\n\t#self.down_sample_2 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])\n        self.down_sample_3 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()])\n        self.down_sample_4 = nn.Sequential(*[nn.Conv2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()])\n        self.error_encode_1 = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,6,stride=2,padding=2),nn.PReLU()])\n        self.error_encode_2 = nn.Sequential(*[nn.ConvTranspose2d(in_channel,in_channel,8,stride=4,padding=2),nn.PReLU()])\n        self.post_conv = common.BasicBlock(conv,in_channel,in_channel,kernel_size,stride=1,bias=True,act=nn.PReLU())\n\n\n    def forward(self, x):\n        x_up = self.multi_source_projection_1(x)\n\n        x_down = self.down_sample_1(x_up)\n        error_up = self.error_encode_1(x-x_down)\n        h_estimate_1 = x_up + error_up\n\t\n        x_up_2 = self.multi_source_projection_2(h_estimate_1)\n        x_down_2 = self.down_sample_3(x_up_2)\n        error_up_2 = self.error_encode_2(x-x_down_2)\n        h_estimate_2 = x_up_2 + error_up_2\n        x_final = self.post_conv(self.down_sample_4(h_estimate_2))\n\n        return x_final, h_estimate_2\n        \n\n        \n\n\nclass MSSR(nn.Module):\n    def __init__(self, args, conv=common.default_conv):\n        super(MSSR, self).__init__()\n\n        #n_convblock = args.n_convblocks\n        n_feats = args.n_feats\n        self.depth = args.depth\n        kernel_size = 3 \n        scale = args.scale[0]\n        \n\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        \n        # define head module\n        m_head = [common.BasicBlock(conv, args.n_colors, n_feats, kernel_size,stride=1,bias=True,bn=False,act=nn.PReLU()),\n        common.BasicBlock(conv,n_feats, n_feats, kernel_size,stride=1,bias=True,bn=False,act=nn.PReLU())]\n\n        # define multiple reconstruction module\n        \n        self.body = 
RecurrentProjection(n_feats)\n\n        # define tail module\n        m_tail = [\n            nn.Conv2d(\n                n_feats*self.depth, args.n_colors, kernel_size,\n                padding=(kernel_size//2)\n            )\n        ]\n\n        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n\n        self.head = nn.Sequential(*m_head)\n        self.tail = nn.Sequential(*m_tail)\n\n    def forward(self, input):\n        x = self.sub_mean(input)\n        x = self.head(x)\n        bag = []\n        for i in range(self.depth):\n            x, h_estimate = self.body(x)\n            bag.append(h_estimate)\n        h_feature = torch.cat(bag, dim=1)\n        h_final = self.tail(h_feature)\n\n        return self.add_mean(h_final)\n"
  },
  {
    "path": "src/model/nlsn.py",
    "content": "from model import common\nfrom model import attention\nimport torch.nn as nn\n\ndef make_model(args, parent=False):\n    if args.dilation:\n        from model import dilated\n        return NLSN(args, dilated.dilated_conv)\n    else:\n        return NLSN(args)\n\n\nclass NLSN(nn.Module):\n    def __init__(self, args, conv=common.default_conv):\n        super(NLSN, self).__init__()\n\n        n_resblock = args.n_resblocks\n        n_feats = args.n_feats\n        kernel_size = 3 \n        scale = args.scale[0]\n        act = nn.ReLU(True)\n\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        m_head = [conv(args.n_colors, n_feats, kernel_size)]\n\n        # define body module\n        m_body = [attention.NonLocalSparseAttention(\n            channels=n_feats, chunk_size=args.chunk_size, n_hashes=args.n_hashes, reduction=4, res_scale=args.res_scale)]         \n\n        for i in range(n_resblock):\n            m_body.append( common.ResBlock(\n                conv, n_feats, kernel_size, act=act, res_scale=args.res_scale\n            ))\n            if (i+1)%8==0:\n                m_body.append(attention.NonLocalSparseAttention(\n                    channels=n_feats, chunk_size=args.chunk_size, n_hashes=args.n_hashes, reduction=4, res_scale=args.res_scale))\n        m_body.append(conv(n_feats, n_feats, kernel_size))\n\n        # define tail module\n        m_tail = [\n            common.Upsampler(conv, scale, n_feats, act=False),\n            nn.Conv2d(\n                n_feats, args.n_colors, kernel_size,\n                padding=(kernel_size//2)\n            )\n        ]\n\n        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n\n        self.head = nn.Sequential(*m_head)\n        self.body = nn.Sequential(*m_body)\n        self.tail = nn.Sequential(*m_tail)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        x = self.head(x)\n\n        res = self.body(x)\n        res += x\n\n        x = self.tail(res)\n        x = self.add_mean(x)\n\n        return x \n\n    def load_state_dict(self, state_dict, strict=True):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') == -1:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n"
  },
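  {
    "path": "examples/nlsn_layout_sketch.py",
    "content": "# Hypothetical editorial sketch, not part of the original NLSN repository.\n# It mirrors the body-construction loop in nlsn.py to show where the sparse\n# attention modules sit: one NLSA block leads the body and another follows\n# every 8th residual block. n_resblocks=32 matches the README's X2 command.\nn_resblocks = 32\n\nbody = ['NLSA']\nfor i in range(n_resblocks):\n    body.append('ResBlock')\n    if (i + 1) % 8 == 0:\n        body.append('NLSA')\nbody.append('Conv')\n\nassert body.count('NLSA') == 1 + n_resblocks // 8  # 5 attention modules\nprint(len(body), 'body modules,', body.count('NLSA'), 'of them NLSA')\n"
  },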
  {
    "path": "src/model/rcan.py",
    "content": "## ECCV-2018-Image Super-Resolution Using Very Deep Residual Channel Attention Networks\n## https://arxiv.org/abs/1807.02758\nfrom model import common\n\nimport torch.nn as nn\nimport torch\ndef make_model(args, parent=False):\n    return RCAN(args)\n\n## Channel Attention (CA) Layer\nclass CALayer(nn.Module):\n    def __init__(self, channel, reduction=16):\n        super(CALayer, self).__init__()\n        # global average pooling: feature --> point\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        # feature channel downscale and upscale --> channel weight\n        #self.a = torch.nn.Parameter(torch.Tensor([0]))\n        #self.a.requires_grad=True\n        \n        self.conv_du = nn.Sequential(\n                nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),\n                nn.ReLU(inplace=True),\n                nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),\n                nn.Sigmoid()\n        )\n\n    def forward(self, x):\n        y = self.avg_pool(x)\n        y = self.conv_du(y)\n        return x * y\n\n## Residual Channel Attention Block (RCAB)\nclass RCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction,\n        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(RCAB, self).__init__()\n        modules_body = []\n        for i in range(2):\n            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))\n            if bn: modules_body.append(nn.BatchNorm2d(n_feat))\n            if i == 0: modules_body.append(act)\n        modules_body.append(CALayer(n_feat, reduction))\n        self.body = nn.Sequential(*modules_body)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x)\n        #res = self.body(x).mul(self.res_scale)\n        res += x\n        return res\n\n## Residual Group (RG)\nclass ResidualGroup(nn.Module):\n    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):\n        super(ResidualGroup, self).__init__()\n        modules_body = []\n        modules_body = [\n            RCAB(\n                conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \\\n            for _ in range(n_resblocks)]\n        modules_body.append(conv(n_feat, n_feat, kernel_size))\n        self.body = nn.Sequential(*modules_body)\n\n    def forward(self, x):\n        res = self.body(x)\n        res += x\n        return res\n\n## Residual Channel Attention Network (RCAN)\nclass RCAN(nn.Module):\n    def __init__(self, args, conv=common.default_conv):\n        super(RCAN, self).__init__()\n        self.a = nn.Parameter(torch.Tensor([0]))\n        self.a.requires_grad=True\n        n_resgroups = args.n_resgroups\n        n_resblocks = args.n_resblocks\n        n_feats = args.n_feats\n        kernel_size = 3\n        reduction = args.reduction \n        scale = args.scale[0]\n        act = nn.ReLU(True)\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        \n        # define head module\n        modules_head = [conv(args.n_colors, n_feats, kernel_size)]\n\n        # define body module\n        modules_body = [\n            ResidualGroup(\n                conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \\\n            for _ in range(n_resgroups)]\n        
modules_body.append(conv(n_feats, n_feats, kernel_size))\n\n        # define tail module\n        modules_tail = [\n            common.Upsampler(conv, scale, n_feats, act=False),\n            conv(n_feats, args.n_colors, kernel_size)]\n\n        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n        self.tail = nn.Sequential(*modules_tail)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        x = self.head(x)\n        res = self.body(x)\n        res += x\n\n        x = self.tail(res)\n        x = self.add_mean(x)\n\n        return x\n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('msa') >= 0 or name.find('a') >= 0:\n                        print('Replacing the pre-trained upsampler with a new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('msa') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))\n"
  },
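  {
    "path": "examples/channel_attention_sketch.py",
    "content": "# Hypothetical editorial sketch, not part of the original NLSN repository.\n# It isolates the squeeze-and-excitation style gate inside rcan.CALayer:\n# global average pooling squeezes each channel to a scalar, two 1x1 convs\n# with a bottleneck produce per-channel weights in (0, 1), and the input is\n# rescaled channel-wise by broadcasting.\nimport torch\nimport torch.nn as nn\n\nchannel, reduction = 16, 4\ngate = nn.Sequential(\n    nn.AdaptiveAvgPool2d(1),                      # [N, C, 1, 1]\n    nn.Conv2d(channel, channel // reduction, 1),  # squeeze\n    nn.ReLU(inplace=True),\n    nn.Conv2d(channel // reduction, channel, 1),  # excite\n    nn.Sigmoid(),\n)\n\nx = torch.randn(2, channel, 8, 8)\nw = gate(x)\nassert w.shape == (2, channel, 1, 1)\ny = x * w  # broadcast: every channel scaled by its own learned weight\nassert y.shape == x.shape\nprint(tuple(w.flatten(1).shape))\n"
  },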
  {
    "path": "src/model/rdn.py",
    "content": "# Residual Dense Network for Image Super-Resolution\n# https://arxiv.org/abs/1802.08797\n\nfrom model import common\n\nimport torch\nimport torch.nn as nn\n\n\ndef make_model(args, parent=False):\n    return RDN(args)\n\nclass RDB_Conv(nn.Module):\n    def __init__(self, inChannels, growRate, kSize=3):\n        super(RDB_Conv, self).__init__()\n        Cin = inChannels\n        G  = growRate\n        self.conv = nn.Sequential(*[\n            nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1),\n            nn.ReLU()\n        ])\n\n    def forward(self, x):\n        out = self.conv(x)\n        return torch.cat((x, out), 1)\n\nclass RDB(nn.Module):\n    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):\n        super(RDB, self).__init__()\n        G0 = growRate0\n        G  = growRate\n        C  = nConvLayers\n        \n        convs = []\n        for c in range(C):\n            convs.append(RDB_Conv(G0 + c*G, G))\n        self.convs = nn.Sequential(*convs)\n        \n        # Local Feature Fusion\n        self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1)\n\n    def forward(self, x):\n        return self.LFF(self.convs(x)) + x\n\nclass RDN(nn.Module):\n    def __init__(self, args):\n        super(RDN, self).__init__()\n        r = args.scale[0]\n        G0 = args.G0\n        kSize = args.RDNkSize\n\n        # number of RDB blocks, conv layers, out channels\n        self.D, C, G = {\n            'A': (20, 6, 32),\n            'B': (16, 8, 64),\n        }[args.RDNconfig]\n\n        # Shallow feature extraction net\n        self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)\n        self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)\n\n        # Redidual dense blocks and dense feature fusion\n        self.RDBs = nn.ModuleList()\n        for i in range(self.D):\n            self.RDBs.append(\n                RDB(growRate0 = G0, growRate = G, nConvLayers = C)\n            )\n\n        # Global Feature Fusion\n        self.GFF = nn.Sequential(*[\n            nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),\n            nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)\n        ])\n\n        # Up-sampling net\n        if r == 2 or r == 3:\n            self.UPNet = nn.Sequential(*[\n                nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),\n                nn.PixelShuffle(r),\n                nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)\n            ])\n        elif r == 4:\n            self.UPNet = nn.Sequential(*[\n                nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),\n                nn.PixelShuffle(2),\n                nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),\n                nn.PixelShuffle(2),\n                nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)\n            ])\n        else:\n            raise ValueError(\"scale must be 2 or 3 or 4.\")\n\n    def forward(self, x):\n        f__1 = self.SFENet1(x)\n        x  = self.SFENet2(f__1)\n\n        RDBs_out = []\n        for i in range(self.D):\n            x = self.RDBs[i](x)\n            RDBs_out.append(x)\n\n        x = self.GFF(torch.cat(RDBs_out,1))\n        x += f__1\n\n        return self.UPNet(x)\n"
  },
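  {
    "path": "examples/rdb_dense_growth_sketch.py",
    "content": "# Hypothetical editorial sketch, not part of the original NLSN repository.\n# It traces the channel counts inside rdn.RDB: every RDB_Conv concatenates\n# its G output channels onto its input, so layer c sees G0 + c*G channels,\n# and local feature fusion (a 1x1 conv) projects G0 + C*G back down to G0.\nimport torch\nimport torch.nn as nn\n\nG0, G, C = 16, 8, 3\nx = torch.randn(1, G0, 12, 12)\n\nfeats = x\nfor c in range(C):\n    conv = nn.Conv2d(G0 + c * G, G, 3, padding=1)\n    feats = torch.cat([feats, torch.relu(conv(feats))], dim=1)\nassert feats.shape[1] == G0 + C * G\n\nlff = nn.Conv2d(G0 + C * G, G0, 1)  # local feature fusion\nout = lff(feats) + x                # local residual learning\nassert out.shape == x.shape\nprint(feats.shape[1], '->', out.shape[1])\n"
  },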
  {
    "path": "src/model/utils/__init__.py",
    "content": ""
  },
  {
    "path": "src/model/utils/tools.py",
    "content": "import os\nimport torch\nimport numpy as np\nfrom PIL import Image\n\nimport torch.nn.functional as F\n\ndef normalize(x):\n    return x.mul_(2).add_(-1)\n\ndef same_padding(images, ksizes, strides, rates):\n    assert len(images.size()) == 4\n    batch_size, channel, rows, cols = images.size()\n    out_rows = (rows + strides[0] - 1) // strides[0]\n    out_cols = (cols + strides[1] - 1) // strides[1]\n    effective_k_row = (ksizes[0] - 1) * rates[0] + 1\n    effective_k_col = (ksizes[1] - 1) * rates[1] + 1\n    padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)\n    padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)\n    # Pad the input\n    padding_top = int(padding_rows / 2.)\n    padding_left = int(padding_cols / 2.)\n    padding_bottom = padding_rows - padding_top\n    padding_right = padding_cols - padding_left\n    paddings = (padding_left, padding_right, padding_top, padding_bottom)\n    images = torch.nn.ZeroPad2d(paddings)(images)\n    return images\n\n\ndef extract_image_patches(images, ksizes, strides, rates, padding='same'):\n    \"\"\"\n    Extract patches from images and put them in the C output dimension.\n    :param padding:\n    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape\n    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for\n     each dimension of images\n    :param strides: [stride_rows, stride_cols]\n    :param rates: [dilation_rows, dilation_cols]\n    :return: A Tensor\n    \"\"\"\n    assert len(images.size()) == 4\n    assert padding in ['same', 'valid']\n    batch_size, channel, height, width = images.size()\n    \n    if padding == 'same':\n        images = same_padding(images, ksizes, strides, rates)\n    elif padding == 'valid':\n        pass\n    else:\n        raise NotImplementedError('Unsupported padding type: {}.\\\n                Only \"same\" or \"valid\" are supported.'.format(padding))\n\n    unfold = torch.nn.Unfold(kernel_size=ksizes,\n                             dilation=rates,\n                             padding=0,\n                             stride=strides)\n    patches = unfold(images)\n    return patches  # [N, C*k*k, L], L is the total number of such blocks\ndef reduce_mean(x, axis=None, keepdim=False):\n    if not axis:\n        axis = range(len(x.shape))\n    for i in sorted(axis, reverse=True):\n        x = torch.mean(x, dim=i, keepdim=keepdim)\n    return x\n\n\ndef reduce_std(x, axis=None, keepdim=False):\n    if not axis:\n        axis = range(len(x.shape))\n    for i in sorted(axis, reverse=True):\n        x = torch.std(x, dim=i, keepdim=keepdim)\n    return x\n\n\ndef reduce_sum(x, axis=None, keepdim=False):\n    if not axis:\n        axis = range(len(x.shape))\n    for i in sorted(axis, reverse=True):\n        x = torch.sum(x, dim=i, keepdim=keepdim)\n    return x\n\n"
  },
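  {
    "path": "examples/unfold_patches_sketch.py",
    "content": "# Hypothetical editorial sketch, not part of the original NLSN repository.\n# It reproduces what utils/tools.extract_image_patches returns with\n# padding='same' (dilation rate 1): the input is zero-padded so the sliding\n# window yields ceil(H/stride)*ceil(W/stride) positions, and nn.Unfold\n# flattens each C*k*k patch into one column.\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nN, C, H, W = 1, 3, 10, 10\nk, stride = 3, 2\nx = torch.randn(N, C, H, W)\n\n# 'same' padding as computed in tools.same_padding\nout_h, out_w = (H + stride - 1) // stride, (W + stride - 1) // stride\npad_h = max(0, (out_h - 1) * stride + k - H)\npad_w = max(0, (out_w - 1) * stride + k - W)\nx_pad = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))\n\npatches = nn.Unfold(kernel_size=k, stride=stride)(x_pad)\nassert patches.shape == (N, C * k * k, out_h * out_w)\nprint(tuple(patches.shape))  # (1, 27, 25)\n"
  },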
  {
    "path": "src/model/vdsr.py",
    "content": "from model import common\n\nimport torch.nn as nn\nimport torch.nn.init as init\n\nurl = {\n    'r20f64': ''\n}\n\ndef make_model(args, parent=False):\n    return VDSR(args)\n\nclass VDSR(nn.Module):\n    def __init__(self, args, conv=common.default_conv):\n        super(VDSR, self).__init__()\n\n        n_resblocks = args.n_resblocks\n        n_feats = args.n_feats\n        kernel_size = 3 \n        self.url = url['r{}f{}'.format(n_resblocks, n_feats)]\n        self.sub_mean = common.MeanShift(args.rgb_range)\n        self.add_mean = common.MeanShift(args.rgb_range, sign=1)\n\n        def basic_block(in_channels, out_channels, act):\n            return common.BasicBlock(\n                conv, in_channels, out_channels, kernel_size,\n                bias=True, bn=False, act=act\n            )\n\n        # define body module\n        m_body = []\n        m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True)))\n        for _ in range(n_resblocks - 2):\n            m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True)))\n        m_body.append(basic_block(n_feats, args.n_colors, None))\n\n        self.body = nn.Sequential(*m_body)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        res = self.body(x)\n        res += x\n        x = self.add_mean(res)\n\n        return x \n\n"
  },
  {
    "path": "src/option.py",
    "content": "import argparse\nimport template\n\nparser = argparse.ArgumentParser(description='EDSR and MDSR')\n\nparser.add_argument('--debug', action='store_true',\n                    help='Enables debug mode')\nparser.add_argument('--template', default='.',\n                    help='You can set various templates in option.py')\n\n# Hardware specifications\nparser.add_argument('--n_threads', type=int, default=18,\n                    help='number of threads for data loading')\nparser.add_argument('--cpu', action='store_true',\n                    help='use cpu only')\nparser.add_argument('--n_GPUs', type=int, default=1,\n                    help='number of GPUs')\nparser.add_argument('--seed', type=int, default=1,\n                    help='random seed')\nparser.add_argument('--local_rank',type=int, default=0)\n# Data specifications\nparser.add_argument('--dir_data', type=str, default='../../../',\n                    help='dataset directory')\nparser.add_argument('--dir_demo', type=str, default='../Demo',\n                    help='demo image directory')\nparser.add_argument('--data_train', type=str, default='DIV2K',\n                    help='train dataset name')\nparser.add_argument('--data_test', type=str, default='DIV2K',\n                    help='test dataset name')\nparser.add_argument('--data_range', type=str, default='1-800/801-810',\n                    help='train/test data range')\nparser.add_argument('--ext', type=str, default='sep',\n                    help='dataset file extension')\nparser.add_argument('--scale', type=str, default='4',\n                    help='super resolution scale')\nparser.add_argument('--patch_size', type=int, default=192,\n                    help='output patch size')\nparser.add_argument('--rgb_range', type=int, default=255,\n                    help='maximum value of RGB')\nparser.add_argument('--n_colors', type=int, default=3,\n                    help='number of color channels to use')\nparser.add_argument('--chunk_size',type=int,default=144,\n                    help='attention bucket size')\nparser.add_argument('--n_hashes',type=int,default=4,\n                    help='number of hash rounds')\nparser.add_argument('--chop', action='store_true',\n                    help='enable memory-efficient forward')\nparser.add_argument('--no_augment', action='store_true',\n                    help='do not use data augmentation')\n\n# Model specifications\nparser.add_argument('--model', default='EDSR',\n                    help='model name')\n\nparser.add_argument('--act', type=str, default='relu',\n                    help='activation function')\nparser.add_argument('--pre_train', type=str, default='.',\n                    help='pre-trained model directory')\nparser.add_argument('--extend', type=str, default='.',\n                    help='pre-trained model directory')\nparser.add_argument('--n_resblocks', type=int, default=20,\n                    help='number of residual blocks')\nparser.add_argument('--n_feats', type=int, default=64,\n                    help='number of feature maps')\nparser.add_argument('--res_scale', type=float, default=1,\n                    help='residual scaling')\nparser.add_argument('--shift_mean', default=True,\n                    help='subtract pixel mean from the input')\nparser.add_argument('--dilation', action='store_true',\n                    help='use dilated convolution')\nparser.add_argument('--precision', type=str, default='single',\n                    choices=('single', 'half'),\n                    
help='FP precision for test (single | half)')\n\n# Option for Residual dense network (RDN)\nparser.add_argument('--G0', type=int, default=64,\n                    help='default number of filters. (Use in RDN)')\nparser.add_argument('--RDNkSize', type=int, default=3,\n                    help='default kernel size. (Use in RDN)')\nparser.add_argument('--RDNconfig', type=str, default='B',\n                    help='parameters config of RDN. (Use in RDN)')\n\nparser.add_argument('--depth', type=int, default=12,\n                    help='number of residual groups')\n# Option for Residual channel attention network (RCAN)\nparser.add_argument('--n_resgroups', type=int, default=10,\n                    help='number of residual groups')\nparser.add_argument('--reduction', type=int, default=16,\n                    help='number of feature maps reduction')\n\n# Training specifications\nparser.add_argument('--reset', action='store_true',\n                    help='reset the training')\nparser.add_argument('--test_every', type=int, default=1000,\n                    help='do test per every N batches')\nparser.add_argument('--epochs', type=int, default=1000,\n                    help='number of epochs to train')\nparser.add_argument('--batch_size', type=int, default=16,\n                    help='input batch size for training')\nparser.add_argument('--split_batch', type=int, default=1,\n                    help='split the batch into smaller chunks')\nparser.add_argument('--self_ensemble', action='store_true',\n                    help='use self-ensemble method for test')\nparser.add_argument('--test_only', action='store_true',\n                    help='set this option to test the model')\nparser.add_argument('--gan_k', type=int, default=1,\n                    help='k value for adversarial loss')\n\n# Optimization specifications\nparser.add_argument('--lr', type=float, default=1e-4,\n                    help='learning rate')\nparser.add_argument('--decay', type=str, default='200',\n                    help='learning rate decay type')\nparser.add_argument('--gamma', type=float, default=0.5,\n                    help='learning rate decay factor for step decay')\nparser.add_argument('--optimizer', default='ADAM',\n                    choices=('SGD', 'ADAM', 'RMSprop'),\n                    help='optimizer to use (SGD | ADAM | RMSprop)')\nparser.add_argument('--momentum', type=float, default=0.9,\n                    help='SGD momentum')\nparser.add_argument('--betas', type=tuple, default=(0.9, 0.999),\n                    help='ADAM beta')\nparser.add_argument('--epsilon', type=float, default=1e-8,\n                    help='ADAM epsilon for numerical stability')\nparser.add_argument('--weight_decay', type=float, default=0,\n                    help='weight decay')\nparser.add_argument('--gclip', type=float, default=0,\n                    help='gradient clipping threshold (0 = no clipping)')\n\n# Loss specifications\nparser.add_argument('--loss', type=str, default='1*L1',\n                    help='loss function configuration')\nparser.add_argument('--skip_threshold', type=float, default='1e8',\n                    help='skipping batch that has large error')\n\n# Log specifications\nparser.add_argument('--save', type=str, default='test',\n                    help='file name to save')\nparser.add_argument('--load', type=str, default='',\n                    help='file name to load')\nparser.add_argument('--resume', type=int, default=0,\n                    help='resume from specific 
checkpoint')\nparser.add_argument('--save_models', action='store_true',\n                    help='save all intermediate models')\nparser.add_argument('--print_every', type=int, default=100,\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--save_results', action='store_true',\n                    help='save output results')\nparser.add_argument('--save_gt', action='store_true',\n                    help='save low-resolution and high-resolution images together')\n\nargs = parser.parse_args()\ntemplate.set_template(args)\n\nargs.scale = list(map(lambda x: int(x), args.scale.split('+')))\nargs.data_train = args.data_train.split('+')\nargs.data_test = args.data_test.split('+')\n\nif args.epochs == 0:\n    args.epochs = 1e8\n\nfor arg in vars(args):\n    if vars(args)[arg] == 'True':\n        vars(args)[arg] = True\n    elif vars(args)[arg] == 'False':\n        vars(args)[arg] = False\n\n"
  },
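  {
    "path": "examples/option_parsing_sketch.py",
    "content": "# Hypothetical editorial sketch, not part of the original NLSN repository.\n# It shows how option.py post-processes the raw argparse strings: '+' acts\n# as a separator for multiple scales and dataset names, which is how flags\n# like --scale 2+3+4 or --data_test Set5+Set14+B100+Urban100 are handled.\nraw_scale = '2+3+4'\nraw_data_test = 'Set5+Set14+B100+Urban100'\n\nscale = list(map(int, raw_scale.split('+')))\ndata_test = raw_data_test.split('+')\n\nassert scale == [2, 3, 4]\nassert data_test == ['Set5', 'Set14', 'B100', 'Urban100']\nprint(scale, data_test)\n"
  },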
  {
    "path": "src/template.py",
    "content": "def set_template(args):\n    # Set the templates here\n    if args.template.find('jpeg') >= 0:\n        args.data_train = 'DIV2K_jpeg'\n        args.data_test = 'DIV2K_jpeg'\n        args.epochs = 200\n        args.decay = '100'\n\n    if args.template.find('EDSR_paper') >= 0:\n        args.model = 'EDSR'\n        args.n_resblocks = 32\n        args.n_feats = 256\n        args.res_scale = 0.1\n\n    if args.template.find('MDSR') >= 0:\n        args.model = 'MDSR'\n        args.patch_size = 48\n        args.epochs = 650\n\n    if args.template.find('DDBPN') >= 0:\n        args.model = 'DDBPN'\n        args.patch_size = 128\n        args.scale = '4'\n\n        args.data_test = 'Set5'\n\n        args.batch_size = 20\n        args.epochs = 1000\n        args.decay = '500'\n        args.gamma = 0.1\n        args.weight_decay = 1e-4\n\n        args.loss = '1*MSE'\n\n    if args.template.find('GAN') >= 0:\n        args.epochs = 200\n        args.lr = 5e-5\n        args.decay = '150'\n\n    if args.template.find('RCAN') >= 0:\n        args.model = 'RCAN'\n        args.n_resgroups = 10\n        args.n_resblocks = 20\n        args.n_feats = 64\n        args.chop = True\n\n    if args.template.find('VDSR') >= 0:\n        args.model = 'VDSR'\n        args.n_resblocks = 20\n        args.n_feats = 64\n        args.patch_size = 41\n        args.lr = 1e-1\n\n"
  },
  {
    "path": "src/trainer.py",
    "content": "import os\nimport math\nfrom decimal import Decimal\n\nimport utility\n\nimport torch\nimport torch.nn.utils as utils\nfrom tqdm import tqdm\n\nclass Trainer():\n    def __init__(self, args, loader, my_model, my_loss, ckp):\n        self.args = args\n        self.scale = args.scale\n\n        self.ckp = ckp\n        self.loader_train = loader.loader_train\n        self.loader_test = loader.loader_test\n        self.model = my_model\n        self.loss = my_loss\n        self.optimizer = utility.make_optimizer(args, self.model)\n\n        if self.args.load != '':\n            self.optimizer.load(ckp.dir, epoch=len(ckp.log))\n\n        self.error_last = 1e8\n\n    def train(self):\n        self.loss.step()\n        epoch = self.optimizer.get_last_epoch() + 1\n        lr = self.optimizer.get_lr()\n\n        self.ckp.write_log(\n            '[Epoch {}]\\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))\n        )\n        self.loss.start_log()\n        self.model.train()\n\n        timer_data, timer_model = utility.timer(), utility.timer()\n        # TEMP\n        self.loader_train.dataset.set_scale(0)\n        for batch, (lr, hr, _,) in enumerate(self.loader_train):\n            lr, hr = self.prepare(lr, hr)\n            timer_data.hold()\n            timer_model.tic()\n\n            self.optimizer.zero_grad()\n            sr = self.model(lr, 0)\n            loss = self.loss(sr, hr)\n            loss.backward()\n            if self.args.gclip > 0:\n                utils.clip_grad_value_(\n                    self.model.parameters(),\n                    self.args.gclip\n                )\n            self.optimizer.step()\n\n            timer_model.hold()\n\n            if (batch + 1) % self.args.print_every == 0:\n                self.ckp.write_log('[{}/{}]\\t{}\\t{:.1f}+{:.1f}s'.format(\n                    (batch + 1) * self.args.batch_size,\n                    len(self.loader_train.dataset),\n                    self.loss.display_loss(batch),\n                    timer_model.release(),\n                    timer_data.release()))\n\n            timer_data.tic()\n\n        self.loss.end_log(len(self.loader_train))\n        self.error_last = self.loss.log[-1, -1]\n        self.optimizer.schedule()\n\n    def test(self):\n        torch.set_grad_enabled(False)\n\n        epoch = self.optimizer.get_last_epoch()\n        self.ckp.write_log('\\nEvaluation:')\n        self.ckp.add_log(\n            torch.zeros(1, len(self.loader_test), len(self.scale))\n        )\n        self.model.eval()\n\n        timer_test = utility.timer()\n        if self.args.save_results: self.ckp.begin_background()\n        for idx_data, d in enumerate(self.loader_test):\n            for idx_scale, scale in enumerate(self.scale):\n                d.dataset.set_scale(idx_scale)\n                for lr, hr, filename in tqdm(d, ncols=80):\n                    lr, hr = self.prepare(lr, hr)\n                    sr = self.model(lr, idx_scale)\n                    sr = utility.quantize(sr, self.args.rgb_range)\n\n                    save_list = [sr]\n                    self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(\n                        sr, hr, scale, self.args.rgb_range, dataset=d\n                    )\n                    if self.args.save_gt:\n                        save_list.extend([lr, hr])\n\n                    if self.args.save_results:\n                        self.ckp.save_results(d, filename[0], save_list, scale)\n\n                self.ckp.log[-1, idx_data, idx_scale] /= len(d)\n  
              best = self.ckp.log.max(0)\n                self.ckp.write_log(\n                    '[{} x{}]\\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(\n                        d.dataset.name,\n                        scale,\n                        self.ckp.log[-1, idx_data, idx_scale],\n                        best[0][idx_data, idx_scale],\n                        best[1][idx_data, idx_scale] + 1\n                    )\n                )\n\n        self.ckp.write_log('Forward: {:.2f}s\\n'.format(timer_test.toc()))\n        self.ckp.write_log('Saving...')\n\n        if self.args.save_results:\n            self.ckp.end_background()\n\n        if not self.args.test_only:\n            self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))\n\n        self.ckp.write_log(\n            'Total: {:.2f}s\\n'.format(timer_test.toc()), refresh=True\n        )\n\n        torch.set_grad_enabled(True)\n\n    def prepare(self, *args):\n        device = torch.device('cpu' if self.args.cpu else 'cuda')\n        def _prepare(tensor):\n            if self.args.precision == 'half': tensor = tensor.half()\n            return tensor.to(device)\n\n        return [_prepare(a) for a in args]\n\n    def terminate(self):\n        if self.args.test_only:\n            self.test()\n            return True\n        else:\n            epoch = self.optimizer.get_last_epoch() + 1\n            return epoch >= self.args.epochs\n\n"
  },
  {
    "path": "src/utility.py",
    "content": "import os\nimport math\nimport time\nimport datetime\nfrom multiprocessing import Process\nfrom multiprocessing import Queue\n\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\n\nimport numpy as np\nimport imageio\n\nimport torch\nimport torch.optim as optim\nimport torch.optim.lr_scheduler as lrs\n\nclass timer():\n    def __init__(self):\n        self.acc = 0\n        self.tic()\n\n    def tic(self):\n        self.t0 = time.time()\n\n    def toc(self, restart=False):\n        diff = time.time() - self.t0\n        if restart: self.t0 = time.time()\n        return diff\n\n    def hold(self):\n        self.acc += self.toc()\n\n    def release(self):\n        ret = self.acc\n        self.acc = 0\n\n        return ret\n\n    def reset(self):\n        self.acc = 0\n\nclass checkpoint():\n    def __init__(self, args):\n        self.args = args\n        self.ok = True\n        self.log = torch.Tensor()\n        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')\n\n        if not args.load:\n            if not args.save:\n                args.save = now\n            self.dir = os.path.join('..', 'experiment', args.save)\n        else:\n            self.dir = os.path.join('..', 'experiment', args.load)\n            if os.path.exists(self.dir):\n                self.log = torch.load(self.get_path('psnr_log.pt'))\n                print('Continue from epoch {}...'.format(len(self.log)))\n            else:\n                args.load = ''\n\n        if args.reset:\n            os.system('rm -rf ' + self.dir)\n            args.load = ''\n\n        os.makedirs(self.dir, exist_ok=True)\n        os.makedirs(self.get_path('model'), exist_ok=True)\n        for d in args.data_test:\n            os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)\n\n        open_type = 'a' if os.path.exists(self.get_path('log.txt'))else 'w'\n        self.log_file = open(self.get_path('log.txt'), open_type)\n        with open(self.get_path('config.txt'), open_type) as f:\n            f.write(now + '\\n\\n')\n            for arg in vars(args):\n                f.write('{}: {}\\n'.format(arg, getattr(args, arg)))\n            f.write('\\n')\n\n        self.n_processes = 8\n\n    def get_path(self, *subdir):\n        return os.path.join(self.dir, *subdir)\n\n    def save(self, trainer, epoch, is_best=False):\n        trainer.model.save(self.get_path('model'), epoch, is_best=is_best)\n        trainer.loss.save(self.dir)\n        trainer.loss.plot_loss(self.dir, epoch)\n\n        self.plot_psnr(epoch)\n        trainer.optimizer.save(self.dir)\n        torch.save(self.log, self.get_path('psnr_log.pt'))\n\n    def add_log(self, log):\n        self.log = torch.cat([self.log, log])\n\n    def write_log(self, log, refresh=False):\n        print(log)\n        self.log_file.write(log + '\\n')\n        if refresh:\n            self.log_file.close()\n            self.log_file = open(self.get_path('log.txt'), 'a')\n\n    def done(self):\n        self.log_file.close()\n\n    def plot_psnr(self, epoch):\n        axis = np.linspace(1, epoch, epoch)\n        for idx_data, d in enumerate(self.args.data_test):\n            label = 'SR on {}'.format(d)\n            fig = plt.figure()\n            plt.title(label)\n            for idx_scale, scale in enumerate(self.args.scale):\n                plt.plot(\n                    axis,\n                    self.log[:, idx_data, idx_scale].numpy(),\n                    label='Scale {}'.format(scale)\n                )\n            
class checkpoint():\n    \"\"\"Bookkeeping helper: manages the experiment directory, log files, PSNR curves, and background image writers.\"\"\"\n    def __init__(self, args):\n        self.args = args\n        self.ok = True\n        self.log = torch.Tensor()\n        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')\n\n        if not args.load:\n            if not args.save:\n                args.save = now\n            self.dir = os.path.join('..', 'experiment', args.save)\n        else:\n            self.dir = os.path.join('..', 'experiment', args.load)\n            if os.path.exists(self.dir):\n                self.log = torch.load(self.get_path('psnr_log.pt'))\n                print('Continue from epoch {}...'.format(len(self.log)))\n            else:\n                args.load = ''\n\n        if args.reset:\n            os.system('rm -rf ' + self.dir)\n            args.load = ''\n\n        os.makedirs(self.dir, exist_ok=True)\n        os.makedirs(self.get_path('model'), exist_ok=True)\n        for d in args.data_test:\n            os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)\n\n        open_type = 'a' if os.path.exists(self.get_path('log.txt')) else 'w'\n        self.log_file = open(self.get_path('log.txt'), open_type)\n        with open(self.get_path('config.txt'), open_type) as f:\n            f.write(now + '\\n\\n')\n            for arg in vars(args):\n                f.write('{}: {}\\n'.format(arg, getattr(args, arg)))\n            f.write('\\n')\n\n        self.n_processes = 8\n\n    def get_path(self, *subdir):\n        return os.path.join(self.dir, *subdir)\n\n    def save(self, trainer, epoch, is_best=False):\n        trainer.model.save(self.get_path('model'), epoch, is_best=is_best)\n        trainer.loss.save(self.dir)\n        trainer.loss.plot_loss(self.dir, epoch)\n\n        self.plot_psnr(epoch)\n        trainer.optimizer.save(self.dir)\n        torch.save(self.log, self.get_path('psnr_log.pt'))\n\n    def add_log(self, log):\n        self.log = torch.cat([self.log, log])\n\n    def write_log(self, log, refresh=False):\n        print(log)\n        self.log_file.write(log + '\\n')\n        if refresh:\n            self.log_file.close()\n            self.log_file = open(self.get_path('log.txt'), 'a')\n\n    def done(self):\n        self.log_file.close()\n\n    def plot_psnr(self, epoch):\n        axis = np.linspace(1, epoch, epoch)\n        for idx_data, d in enumerate(self.args.data_test):\n            label = 'SR on {}'.format(d)\n            fig = plt.figure()\n            plt.title(label)\n            for idx_scale, scale in enumerate(self.args.scale):\n                plt.plot(\n                    axis,\n                    self.log[:, idx_data, idx_scale].numpy(),\n                    label='Scale {}'.format(scale)\n                )\n            plt.legend()\n            plt.xlabel('Epochs')\n            plt.ylabel('PSNR')\n            plt.grid(True)\n            plt.savefig(self.get_path('test_{}.pdf'.format(d)))\n            plt.close(fig)\n\n    def begin_background(self):\n        self.queue = Queue()\n\n        # Worker loop: block on the queue; a (None, None) sentinel shuts the worker down.\n        def bg_target(queue):\n            while True:\n                filename, tensor = queue.get()\n                if filename is None: break\n                imageio.imwrite(filename, tensor.numpy())\n\n        self.process = [\n            Process(target=bg_target, args=(self.queue,))\n            for _ in range(self.n_processes)\n        ]\n\n        for p in self.process: p.start()\n\n    def end_background(self):\n        # Send one sentinel per worker, then wait for the queue to drain.\n        for _ in range(self.n_processes): self.queue.put((None, None))\n        while not self.queue.empty(): time.sleep(1)\n        for p in self.process: p.join()\n\n    def save_results(self, dataset, filename, save_list, scale):\n        if self.args.save_results:\n            filename = self.get_path(\n                'results-{}'.format(dataset.dataset.name),\n                '{}_x{}_'.format(filename, scale)\n            )\n\n            postfix = ('SR', 'LR', 'HR')\n            for v, p in zip(save_list, postfix):\n                normalized = v[0].mul(255 / self.args.rgb_range)\n                tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()\n                self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))\n\ndef quantize(img, rgb_range):\n    # Map to [0, 255], round to integer levels, then map back to the working range.\n    pixel_range = 255 / rgb_range\n    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)\n\ndef calc_psnr(sr, hr, scale, rgb_range, dataset=None):\n    if hr.nelement() == 1: return 0\n\n    diff = (sr - hr) / rgb_range\n    if dataset and dataset.dataset.benchmark:\n        shave = scale\n        if diff.size(1) > 1:\n            # Benchmark datasets are evaluated on the luminance (Y) channel.\n            gray_coeffs = [65.738, 129.057, 25.064]\n            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256\n            diff = diff.mul(convert).sum(dim=1)\n    else:\n        shave = scale + 6\n\n    valid = diff[..., shave:-shave, shave:-shave]\n    mse = valid.pow(2).mean()\n\n    return -10 * math.log10(mse)\n\n
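# Worked example (illustrative): with rgb_range=1, a uniform per-pixel error of\n# 1/255 gives mse = (1/255) ** 2, so calc_psnr returns\n#   -10 * log10((1/255) ** 2) = 20 * log10(255), approx. 48.13 dB.\n\n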
def make_optimizer(args, target):\n    '''\n        Make an optimizer and attach a MultiStepLR scheduler to it.\n    '''\n    # optimizer\n    trainable = filter(lambda x: x.requires_grad, target.parameters())\n    kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}\n\n    if args.optimizer == 'SGD':\n        optimizer_class = optim.SGD\n        kwargs_optimizer['momentum'] = args.momentum\n    elif args.optimizer == 'ADAM':\n        optimizer_class = optim.Adam\n        kwargs_optimizer['betas'] = args.betas\n        kwargs_optimizer['eps'] = args.epsilon\n    elif args.optimizer == 'RMSprop':\n        optimizer_class = optim.RMSprop\n        kwargs_optimizer['eps'] = args.epsilon\n\n    # scheduler: e.g. --decay 200-400-600-800 scales the LR by gamma at each milestone\n    milestones = list(map(lambda x: int(x), args.decay.split('-')))\n    kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}\n    scheduler_class = lrs.MultiStepLR\n\n    class CustomOptimizer(optimizer_class):\n        def __init__(self, *args, **kwargs):\n            super(CustomOptimizer, self).__init__(*args, **kwargs)\n\n        def _register_scheduler(self, scheduler_class, **kwargs):\n            self.scheduler = scheduler_class(self, **kwargs)\n\n        def save(self, save_dir):\n            torch.save(self.state_dict(), self.get_dir(save_dir))\n\n        def load(self, load_dir, epoch=1):\n            self.load_state_dict(torch.load(self.get_dir(load_dir)))\n            # Fast-forward the scheduler so the LR matches the resumed epoch.\n            if epoch > 1:\n                for _ in range(epoch): self.scheduler.step()\n\n        def get_dir(self, dir_path):\n            return os.path.join(dir_path, 'optimizer.pt')\n\n        def schedule(self):\n            self.scheduler.step()\n\n        def get_lr(self):\n            return self.scheduler.get_lr()[0]\n\n        def get_last_epoch(self):\n            return self.scheduler.last_epoch\n\n    optimizer = CustomOptimizer(trainable, **kwargs_optimizer)\n    optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)\n    return optimizer\n
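\n# Usage sketch (illustrative, not part of the original training loop):\n#   optimizer = make_optimizer(args, model)\n#   for epoch in range(args.epochs):\n#       ...train one epoch...\n#       optimizer.schedule()             # advance the MultiStepLR scheduler\n#       current_lr = optimizer.get_lr()  # current learning rate\n\n"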
  },
  {
    "path": "src/utils/__init__.py",
    "content": ""
  },
  {
    "path": "src/utils/tools.py",
    "content": "import os\nimport torch\nimport numpy as np\nfrom PIL import Image\n\nimport torch.nn.functional as F\n\ndef normalize(x):\n    return x.mul_(2).add_(-1)\n\ndef same_padding(images, ksizes, strides, rates):\n    assert len(images.size()) == 4\n    batch_size, channel, rows, cols = images.size()\n    out_rows = (rows + strides[0] - 1) // strides[0]\n    out_cols = (cols + strides[1] - 1) // strides[1]\n    effective_k_row = (ksizes[0] - 1) * rates[0] + 1\n    effective_k_col = (ksizes[1] - 1) * rates[1] + 1\n    padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)\n    padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)\n    # Pad the input\n    padding_top = int(padding_rows / 2.)\n    padding_left = int(padding_cols / 2.)\n    padding_bottom = padding_rows - padding_top\n    padding_right = padding_cols - padding_left\n    paddings = (padding_left, padding_right, padding_top, padding_bottom)\n    images = torch.nn.ZeroPad2d(paddings)(images)\n    return images\n\n\ndef extract_image_patches(images, ksizes, strides, rates, padding='same'):\n    \"\"\"\n    Extract patches from images and put them in the C output dimension.\n    :param padding:\n    :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape\n    :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for\n     each dimension of images\n    :param strides: [stride_rows, stride_cols]\n    :param rates: [dilation_rows, dilation_cols]\n    :return: A Tensor\n    \"\"\"\n    assert len(images.size()) == 4\n    assert padding in ['same', 'valid']\n    batch_size, channel, height, width = images.size()\n    \n    if padding == 'same':\n        images = same_padding(images, ksizes, strides, rates)\n    elif padding == 'valid':\n        pass\n    else:\n        raise NotImplementedError('Unsupported padding type: {}.\\\n                Only \"same\" or \"valid\" are supported.'.format(padding))\n\n    unfold = torch.nn.Unfold(kernel_size=ksizes,\n                             dilation=rates,\n                             padding=0,\n                             stride=strides)\n    patches = unfold(images)\n    return patches  # [N, C*k*k, L], L is the total number of such blocks\ndef reduce_mean(x, axis=None, keepdim=False):\n    if not axis:\n        axis = range(len(x.shape))\n    for i in sorted(axis, reverse=True):\n        x = torch.mean(x, dim=i, keepdim=keepdim)\n    return x\n\n\ndef reduce_std(x, axis=None, keepdim=False):\n    if not axis:\n        axis = range(len(x.shape))\n    for i in sorted(axis, reverse=True):\n        x = torch.std(x, dim=i, keepdim=keepdim)\n    return x\n\n\ndef reduce_sum(x, axis=None, keepdim=False):\n    if not axis:\n        axis = range(len(x.shape))\n    for i in sorted(axis, reverse=True):\n        x = torch.sum(x, dim=i, keepdim=keepdim)\n    return x\n\n"
  },
  {
    "path": "src/videotester.py",
    "content": "import os\nimport math\n\nimport utility\nfrom data import common\n\nimport torch\nimport cv2\n\nfrom tqdm import tqdm\n\nclass VideoTester():\n    def __init__(self, args, my_model, ckp):\n        self.args = args\n        self.scale = args.scale\n\n        self.ckp = ckp\n        self.model = my_model\n\n        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))\n\n    def test(self):\n        torch.set_grad_enabled(False)\n\n        self.ckp.write_log('\\nEvaluation on video:')\n        self.model.eval()\n\n        timer_test = utility.timer()\n        for idx_scale, scale in enumerate(self.scale):\n            vidcap = cv2.VideoCapture(self.args.dir_demo)\n            total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))\n            vidwri = cv2.VideoWriter(\n                self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)),\n                cv2.VideoWriter_fourcc(*'XVID'),\n                vidcap.get(cv2.CAP_PROP_FPS),\n                (\n                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),\n                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n                )\n            )\n\n            tqdm_test = tqdm(range(total_frames), ncols=80)\n            for _ in tqdm_test:\n                success, lr = vidcap.read()\n                if not success: break\n\n                lr, = common.set_channel(lr, n_channels=self.args.n_colors)\n                lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)\n                lr, = self.prepare(lr.unsqueeze(0))\n                sr = self.model(lr, idx_scale)\n                sr = utility.quantize(sr, self.args.rgb_range).squeeze(0)\n\n                normalized = sr * 255 / self.args.rgb_range\n                ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()\n                vidwri.write(ndarr)\n\n            vidcap.release()\n            vidwri.release()\n\n        self.ckp.write_log(\n            'Total: {:.2f}s\\n'.format(timer_test.toc()), refresh=True\n        )\n        torch.set_grad_enabled(True)\n\n    def prepare(self, *args):\n        device = torch.device('cpu' if self.args.cpu else 'cuda')\n        def _prepare(tensor):\n            if self.args.precision == 'half': tensor = tensor.half()\n            return tensor.to(device)\n\n        return [_prepare(a) for a in args]\n\n"
  }
]