Repository: The-Learning-And-Vision-Atelier-LAVA/DASR
Branch: main
Commit: 51524af8021c
Files: 33
Total size: 91.3 MB

Directory structure:
gitextract_01sw5xtg/

├── LICENSE
├── README.md
├── data/
│   ├── __init__.py
│   ├── benchmark.py
│   ├── common.py
│   ├── df2k.py
│   └── multiscalesrdata.py
├── dataloader.py
├── experiment/
│   ├── blindsr_x2_bicubic_iso/
│   │   └── model/
│   │       └── model_600.pt
│   ├── blindsr_x3_bicubic_iso/
│   │   └── model/
│   │       └── model_600.pt
│   ├── blindsr_x4_bicubic_aniso/
│   │   └── model/
│   │       └── model_600.pt
│   └── blindsr_x4_bicubic_iso/
│       └── model/
│           └── model_600.pt
├── loss/
│   ├── __init__.py
│   ├── adversarial.py
│   ├── discriminator.py
│   └── vgg.py
├── main.py
├── main.sh
├── moco/
│   ├── __init__.py
│   └── builder.py
├── model/
│   ├── __init__.py
│   ├── blindsr.py
│   └── common.py
├── option.py
├── quick_test.py
├── quick_test.sh
├── template.py
├── test.py
├── test.sh
├── trainer.py
├── utility.py
└── utils/
    ├── __init__.py
    └── util.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2022 The Learning and Vision Atelier (LAVA)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# DASR
PyTorch implementation of "Unsupervised Degradation Representation Learning for Blind Super-Resolution", CVPR 2021

[[arXiv]](http://arxiv.org/pdf/2104.00416) [[CVF]](https://openaccess.thecvf.com/content/CVPR2021/papers/Wang_Unsupervised_Degradation_Representation_Learning_for_Blind_Super-Resolution_CVPR_2021_paper.pdf) [[Supp]](https://openaccess.thecvf.com/content/CVPR2021/supplemental/Wang_Unsupervised_Degradation_Representation_CVPR_2021_supplemental.pdf)


## Overview

<p align="center"> <img src="Figs/fig.1.png" width="50%"> </p>


<p align="center"> <img src="Figs/fig.2.png" width="100%"> </p>


## Requirements
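- Python 3.6
- PyTorch == 1.1.0 (`dataloader.py` subclasses the private `_DataLoaderIter`, which was removed in later PyTorch releases)
- numpy
- scikit-image (`skimage`)
- imageio
- matplotlib
- opencv-python (`cv2`)
- torchvision (only needed for the VGG perceptual loss in `loss/vgg.py`)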


## Train
### 1. Prepare training data 

1.1 Download the [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) dataset and the [Flickr2K](http://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) dataset.

1.2 Put the HR images from both datasets into `your_data_path/DF2K/HR` to build the DF2K dataset. The training loader in `data/multiscalesrdata.py` reads only these HR images, so no LR folder is needed for training.
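Step 1.2 can be scripted. The following is a minimal sketch that is not part of this repository: `build_df2k` is a hypothetical helper, and the source folders you pass in (wherever you extracted the DIV2K and Flickr2K HR images) are assumptions about your local layout. It copies every `.png` from both folders into `your_data_path/DF2K/HR` and refuses to overwrite a file that already exists there.

```python
import shutil
from pathlib import Path


def build_df2k(div2k_hr: Path, flickr2k_hr: Path, data_path: Path) -> int:
    """Copy all HR .png files from two folders into data_path/DF2K/HR.

    Hypothetical helper (not part of DASR). Returns the number of copied files.
    """
    out_dir = data_path / 'DF2K' / 'HR'
    out_dir.mkdir(parents=True, exist_ok=True)
    copied = 0
    for src_dir in (div2k_hr, flickr2k_hr):
        # Only .png files are copied, matching the '.png' extension SRData scans for
        for img in sorted(src_dir.glob('*.png')):
            dst = out_dir / img.name
            # Never silently overwrite an image coming from the other dataset
            if dst.exists():
                raise FileExistsError(f'name clash in DF2K/HR: {img.name}')
            shutil.copy2(img, dst)
            copied += 1
    return copied
```

For example, `build_df2k(Path('DIV2K_train_HR'), Path('Flickr2K_HR'), Path('your_data_path'))`, where both source paths are placeholders for your own folders.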

### 2. Begin to train
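Run `./main.sh` to train on the DF2K dataset. Before running, set `dir_data` in `main.sh` to `your_data_path`. The script contains two commands, one for noise-free degradations with isotropic Gaussian blurs (x2) and one for general degradations with anisotropic Gaussian blurs and noise (x4); they run one after the other, so comment out the one you do not need.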


## Test
### 1. Prepare test data 
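Download the [benchmark datasets](https://github.com/xinntao/BasicSR/blob/a19aac61b277f64be050cef7fe578a121d944a0e/docs/Datasets.md) (e.g., Set5, Set14 and other test sets) and place the HR and LR images of each set in `your_data_path/benchmark/<dataset_name>/HR` and `your_data_path/benchmark/<dataset_name>/LR_bicubic`, respectively. `data/__init__.py` routes Set5, Set14, B100, Manga109 and Urban100 to the benchmark loader.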
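The loader in `data/benchmark.py` looks for `<dir_data>/benchmark/<name>/HR` (and sets `LR_bicubic` as the LR folder), with `.png` files in both. The sketch below is a hypothetical helper, not part of the repository: it checks that a test set sits where `Benchmark` expects it and lists the HR files it would scan, mirroring the sorted `glob` in `SRData._scan`.

```python
import glob
import os

# Test-set names that data/__init__.py routes to data.benchmark.Benchmark
BENCHMARKS = ('Set5', 'Set14', 'B100', 'Manga109', 'Urban100')


def list_benchmark_hr(dir_data: str, name: str) -> list:
    """Return the sorted HR .png paths for one benchmark set.

    Hypothetical helper: it mirrors the paths used by data/benchmark.py
    (<dir_data>/benchmark/<name>/HR) and the sorted glob in SRData._scan.
    """
    if name not in BENCHMARKS:
        raise ValueError(f'{name!r} is not one of {BENCHMARKS}')
    dir_hr = os.path.join(dir_data, 'benchmark', name, 'HR')
    if not os.path.isdir(dir_hr):
        raise FileNotFoundError(f'missing HR folder: {dir_hr}')
    files = sorted(glob.glob(os.path.join(dir_hr, '*.png')))
    if not files:
        raise FileNotFoundError(f'no .png images in {dir_hr}')
    return files
```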


### 2. Begin to test
Run `./test.sh` to test on benchmark datasets. Before running, set `dir_data` in `test.sh` to `your_data_path`.


## Quick Test on an LR Image
Run `./quick_test.sh` to test on a single LR image. Before running, set `img_dir` in `quick_test.sh` to `your_img_path`.

## Visualization of Degradation Representations
<p align="center"> <img src="Figs/fig.6.png" width="50%"> </p>

## Comparative Results
### Noise-Free Degradations with Isotropic Gaussian Kernels

<p align="center"> <img src="Figs/tab2.png" width="100%"> </p>

<p align="center"> <img src="Figs/fig.5.png" width="100%"> </p>
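For reference, an isotropic Gaussian kernel has a single width parameter σ; in `main.sh` the noise-free setting passes `--sig_min=0.2` and `--sig_max=4.0` as its range. The NumPy sketch below builds such a kernel. It is an illustration under assumptions, not the repository's own degradation code: the 21x21 kernel size and uniform sampling of σ are assumptions, and the subsequent blurring and bicubic downsampling are omitted.

```python
import numpy as np


def isotropic_gaussian_kernel(ksize: int = 21, sigma: float = 2.0) -> np.ndarray:
    """Return a ksize x ksize isotropic Gaussian kernel normalized to sum to 1."""
    # Pixel offsets from the kernel center, e.g. -10..10 for ksize=21
    ax = np.arange(ksize, dtype=np.float64) - (ksize - 1) / 2.0
    xx, yy = np.meshgrid(ax, ax)
    kernel = np.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma ** 2))
    return kernel / kernel.sum()


def random_isotropic_kernel(rng: np.random.Generator, ksize: int = 21,
                            sig_min: float = 0.2, sig_max: float = 4.0):
    """Sample sigma ~ U(sig_min, sig_max) (assumed distribution) and build the kernel."""
    sigma = rng.uniform(sig_min, sig_max)
    return isotropic_gaussian_kernel(ksize, sigma), sigma
```

`random_isotropic_kernel(np.random.default_rng(0))` returns a kernel whose weights sum to 1 together with the sampled σ; a larger σ spreads the weight further from the center and therefore blurs more.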


### General Degradations with Anisotropic Gaussian Kernels and Noise
<p align="center"> <img src="Figs/tab3.png" width="100%"> </p>

<p align="center"> <img src="Figs/fig.7.png" width="100%"> </p>
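The general setting in `main.sh` instead passes `--lambda_min=0.2`, `--lambda_max=4.0` and `--noise=25.0`. A common way to build an anisotropic Gaussian kernel, sketched below, uses the covariance `R(theta) @ diag(lambda1, lambda2) @ R(theta).T` with eigenvalues λ1, λ2 and rotation angle θ, and then adds zero-mean Gaussian noise. Treating the `lambda` flags as covariance eigenvalues, and `--noise` as the upper bound of the noise standard deviation on a 0 to 255 scale, are assumptions of this sketch rather than a description of the repository's implementation.

```python
import numpy as np


def anisotropic_gaussian_kernel(ksize: int = 21, lambda1: float = 2.0,
                                lambda2: float = 0.5, theta: float = 0.0) -> np.ndarray:
    """Gaussian kernel with covariance R(theta) diag(lambda1, lambda2) R(theta)^T."""
    c, s = np.cos(theta), np.sin(theta)
    rot = np.array([[c, -s], [s, c]])
    inv_cov = np.linalg.inv(rot @ np.diag([lambda1, lambda2]) @ rot.T)
    ax = np.arange(ksize, dtype=np.float64) - (ksize - 1) / 2.0
    xx, yy = np.meshgrid(ax, ax)            # xx: column offset, yy: row offset
    coords = np.stack([xx, yy], axis=-1)    # shape (ksize, ksize, 2)
    # Squared Mahalanobis distance p^T inv_cov p for every grid offset p
    dist = np.einsum('...i,ij,...j->...', coords, inv_cov, coords)
    kernel = np.exp(-0.5 * dist)
    return kernel / kernel.sum()


def add_gaussian_noise(img: np.ndarray, noise_max: float = 25.0,
                       rng: np.random.Generator = None):
    """Add zero-mean Gaussian noise with std ~ U(0, noise_max) to a 0-255 image."""
    rng = np.random.default_rng() if rng is None else rng
    sigma = rng.uniform(0.0, noise_max)
    noisy = img.astype(np.float64) + rng.normal(0.0, sigma, size=img.shape)
    return np.clip(noisy, 0.0, 255.0), sigma
```

Sampling λ1, λ2 and θ per image yields kernels that are stretched and rotated differently; with λ1 = λ2 the kernel reduces to an isotropic one.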

### Unseen Degradations 

<p align="center"> <img src="Figs/fig.III.png" width="50%"> </p>

### Real Degradations (AIM real-world SR challenge)

<p align="center"> <img src="Figs/fig.VII.png" width="50%"> </p>

## Citation
```
@InProceedings{Wang2021Unsupervised,
  author    = {Wang, Longguang and Wang, Yingqian and Dong, Xiaoyu and Xu, Qingyu and Yang, Jungang and An, Wei and Guo, Yulan},
  title     = {Unsupervised Degradation Representation Learning for Blind Super-Resolution},
  booktitle = {CVPR},
  year      = {2021},
}
```

## Acknowledgements
This code is built on [EDSR (PyTorch)](https://github.com/thstkdgus35/EDSR-PyTorch), [IKC](https://github.com/yuanjunchai/IKC) and [MoCo](https://github.com/facebookresearch/moco). We thank the authors for sharing their code.



================================================
FILE: data/__init__.py
================================================
from importlib import import_module
from dataloader import MSDataLoader

class Data:
    def __init__(self, args):
        self.loader_train = None
        if not args.test_only:
            module_train = import_module('data.' + args.data_train.lower())     ## import the module that defines the training dataset
            trainset = getattr(module_train, args.data_train)(args)             ## instantiate it; args.data_train is the dataset class name
            self.loader_train = MSDataLoader(
                args,
                trainset,
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=not args.cpu
            )

        if args.data_test in ['Set5', 'Set14', 'B100', 'Manga109', 'Urban100']:
            module_test = import_module('data.benchmark')
            testset = getattr(module_test, 'Benchmark')(args, name=args.data_test,train=False)
        else:
            module_test = import_module('data.' + args.data_test.lower())
            testset = getattr(module_test, args.data_test)(args, train=False)

        self.loader_test = MSDataLoader(
            args,
            testset,
            batch_size=1,
            shuffle=False,
            pin_memory=not args.cpu
        )



================================================
FILE: data/benchmark.py
================================================
import os
from data import common
from data import multiscalesrdata as srdata


class Benchmark(srdata.SRData):
    def __init__(self, args, name='', train=True):
        super(Benchmark, self).__init__(
            args, name=name, train=train, benchmark=True
        )

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data,'benchmark', self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('.png','.png')
        print(self.dir_hr)
        print(self.dir_lr)


================================================
FILE: data/common.py
================================================
import random
import numpy as np
import skimage.color as sc
import torch


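def get_patch(img, patch_size=48, scale=1):
    th, tw = img.shape[:2]  ## HR image

    tp = round(scale * patch_size)  ## HR patch size

    ## randrange's upper bound is exclusive, so +1 allows patches touching the
    ## right/bottom border (and avoids randrange(0, 0) when the image is tp wide)
    tx = random.randrange(0, tw - tp + 1)
    ty = random.randrange(0, th - tp + 1)

    return img[ty:ty + tp, tx:tx + tp, :]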


def set_channel(img, n_channels=3):
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)

    c = img.shape[2]
    if n_channels == 1 and c == 3:
        img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
    elif n_channels == 3 and c == 1:
        img = np.concatenate([img] * n_channels, 2)

    return img


def np2Tensor(img, rgb_range=255):
    np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
    tensor = torch.from_numpy(np_transpose).float()
    tensor.mul_(rgb_range / 255)

    return tensor


def augment(img, hflip=True, rot=True):
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    if hflip: img = img[:, ::-1, :]
    if vflip: img = img[::-1, :, :]
    if rot90: img = img.transpose(1, 0, 2)

    return img



================================================
FILE: data/df2k.py
================================================
import os
from data import multiscalesrdata


class DF2K(multiscalesrdata.SRData):
    def __init__(self, args, name='DF2K', train=True, benchmark=False):
        super(DF2K, self).__init__(args, name=name, train=train, benchmark=benchmark)

    def _scan(self):
        names_hr = super(DF2K, self)._scan()
        names_hr = names_hr[self.begin - 1:self.end]

        return names_hr

    def _set_filesystem(self, dir_data):
        super(DF2K, self)._set_filesystem(dir_data)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')



================================================
FILE: data/multiscalesrdata.py
================================================
import os
import glob

from data import common
import pickle
import numpy as np
import imageio

import torch
import torch.utils.data as data


class SRData(data.Dataset):
    def __init__(self, args, name='', train=True, benchmark=False):
        self.args = args
        self.name = name
        self.train = train
        self.split = 'train' if train else 'test'
        self.do_eval = True
        self.benchmark = benchmark
        self.scale = args.scale
        self.idx_scale = 0

        data_range = [r.split('-') for r in args.data_range.split('/')]
        if train:
            data_range = data_range[0]
        else:
            if args.test_only and len(data_range) == 1:
                data_range = data_range[0]
            else:
                data_range = data_range[1]
        self.begin, self.end = list(map(lambda x: int(x), data_range))
        self._set_filesystem(args.dir_data)
        if args.ext.find('img') < 0:
            path_bin = os.path.join(self.apath, 'bin')
            os.makedirs(path_bin, exist_ok=True)

        list_hr = self._scan()
        if args.ext.find('bin') >= 0:
            # Binary files are stored in the 'bin' folder.
            # If the binary file exists, load it; otherwise, create it.
            self.images_hr = self._check_and_load(
                args.ext, list_hr, self._name_hrbin()
            )
        else:
            if args.ext.find('img') >= 0 or benchmark:
                self.images_hr = list_hr
            elif args.ext.find('sep') >= 0:
                os.makedirs(
                    self.dir_hr.replace(self.apath, path_bin),
                    exist_ok=True
                )

                self.images_hr = []
                for h in list_hr:
                    b = h.replace(self.apath, path_bin)
                    b = b.replace(self.ext[0], '.pt')
                    self.images_hr.append(b)
                    self._check_and_load(
                        args.ext, [h], b, verbose=True, load=False
                    )

        if train:
            self.repeat = args.test_every // (len(self.images_hr) // args.batch_size)

    # The functions below are used to prepare images
    def _scan(self):
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
        )
        print(len(names_hr))

        return names_hr

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('.png', '.png')

    def _name_hrbin(self):
        return os.path.join(
            self.apath,
            'bin',
            '{}_bin_HR.pt'.format(self.split)
        )

    def _name_lrbin(self, scale):
        return os.path.join(
            self.apath,
            'bin',
            '{}_bin_LR_X{}.pt'.format(self.split, scale)
        )

    def _check_and_load(self, ext, l, f, verbose=True, load=True):
        if os.path.isfile(f) and ext.find('reset') < 0:
            if load:
                if verbose: print('Loading {}...'.format(f))
                with open(f, 'rb') as _f:
                    ret = pickle.load(_f)
                return ret
            else:
                return None
        else:
            if verbose:
                if ext.find('reset') >= 0:
                    print('Making a new binary: {}'.format(f))
                else:
                    print('{} does not exist. Now making binary...'.format(f))
            b = [{
                'name': os.path.splitext(os.path.basename(_l))[0],
                'image': imageio.imread(_l)
            } for _l in l]
            with open(f, 'wb') as _f:
                pickle.dump(b, _f)
            return b

    def __getitem__(self, idx):
        hr, filename = self._load_file(idx)
        hr = self.get_patch(hr)
        hr = [common.set_channel(img, n_channels=self.args.n_colors) for img in hr]
        hr_tensor = [common.np2Tensor(img, rgb_range=self.args.rgb_range)
                     for img in hr]

        return torch.stack(hr_tensor, 0), filename

    def __len__(self):
        if self.train:
            return len(self.images_hr) * self.repeat
        else:
            return len(self.images_hr)

    def _get_index(self, idx):
        if self.train:
            return idx % len(self.images_hr)
        else:
            return idx

    def _load_file(self, idx):
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]

        if self.args.ext.find('bin') >= 0:
            filename = f_hr['name']
            hr = f_hr['image']
        else:
            filename, _ = os.path.splitext(os.path.basename(f_hr))
            if self.args.ext == 'img' or self.benchmark:
                hr = imageio.imread(f_hr)
            elif self.args.ext.find('sep') >= 0:
                # The .pt files in 'bin' were written with pickle.dump in _check_and_load
                with open(f_hr, 'rb') as _f:
                    hr = pickle.load(_f)[0]['image']

        return hr, filename

    def get_patch(self, hr):
        scale = self.scale[self.idx_scale]
        if self.train:
            out = []
            hr = common.augment(hr) if not self.args.no_augment else hr
            # extract two patches from each image
            for _ in range(2):
                hr_patch = common.get_patch(
                    hr,
                    patch_size=self.args.patch_size,
                    scale=scale
                )
                out.append(hr_patch)
        else:
            out = [hr]
        return out

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale



================================================
FILE: dataloader.py
================================================
import sys
import threading
import queue
import random
import collections

import torch
import torch.multiprocessing as multiprocessing

from torch._C import _set_worker_signal_handlers
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataloader import _DataLoaderIter
from torch.utils.data import _utils

if sys.version_info[0] == 2:
    import Queue as queue
else:
    import queue

def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
    global _use_shared_memory
    _use_shared_memory = True
    _set_worker_signal_handlers()

    torch.set_num_threads(1)
    torch.manual_seed(seed)
    while True:
        r = index_queue.get()
        if r is None:
            break
        idx, batch_indices = r
        try:
            idx_scale = 0
            if len(scale) > 1 and dataset.train:
                idx_scale = random.randrange(0, len(scale))
                dataset.set_scale(idx_scale)

            samples = collate_fn([dataset[i] for i in batch_indices])
            samples.append(idx_scale)

        except Exception:
            data_queue.put((idx, _utils.ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))

class _MSDataLoaderIter(_DataLoaderIter):
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()

        self.sample_iter = iter(self.batch_sampler)

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.index_queues = [
                multiprocessing.Queue() for _ in range(self.num_workers)
            ]
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}

            base_seed = torch.LongTensor(1).random_()[0]
            self.workers = [
                multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        self.index_queues[i],
                        self.worker_result_queue,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                for i in range(self.num_workers)]

            if self.pin_memory or self.timeout > 0:
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(self.worker_result_queue, self.data_queue, maybe_device_id, self.done_event))
                self.pin_memory_thread.daemon = True
                self.pin_memory_thread.start()
            else:
                self.data_queue = self.worker_result_queue

            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()

            _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()

class MSDataLoader(DataLoader):
    def __init__(
        self, args, dataset, batch_size=1, shuffle=False,
        sampler=None, batch_sampler=None,
        collate_fn=_utils.collate.default_collate, pin_memory=False, drop_last=True,
        timeout=0, worker_init_fn=None):

        super(MSDataLoader, self).__init__(
            dataset, batch_size=batch_size, shuffle=shuffle,
            sampler=sampler, batch_sampler=batch_sampler,
            num_workers=args.n_threads, collate_fn=collate_fn,
            pin_memory=pin_memory, drop_last=drop_last,
            timeout=timeout, worker_init_fn=worker_init_fn)

        self.scale = args.scale

    def __iter__(self):
        return _MSDataLoaderIter(self)


================================================
FILE: experiment/blindsr_x2_bicubic_iso/model/model_600.pt
================================================
[File too large to display: 22.3 MB]

================================================
FILE: experiment/blindsr_x3_bicubic_iso/model/model_600.pt
================================================
[File too large to display: 23.0 MB]

================================================
FILE: experiment/blindsr_x4_bicubic_aniso/model/model_600.pt
================================================
[File too large to display: 22.9 MB]

================================================
FILE: experiment/blindsr_x4_bicubic_iso/model/model_600.pt
================================================
[File too large to display: 22.9 MB]

================================================
FILE: loss/__init__.py
================================================
import os
from importlib import import_module

import matplotlib.pyplot as plt

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


class Loss(nn.modules.loss._Loss):
    def __init__(self, args, ckp):
        super(Loss, self).__init__()
        print('Preparing loss function:')

        self.n_GPUs = args.n_GPUs
        self.loss = []
        self.loss_module = nn.ModuleList()
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':
                loss_function = nn.MSELoss()
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif loss_type == 'CE':
                loss_function = nn.CrossEntropyLoss()
            elif loss_type.find('VGG') >= 0:
                module = import_module('loss.vgg')
                loss_function = getattr(module, 'VGG')(
                    loss_type[3:],
                    rgb_range=args.rgb_range
                )
            elif loss_type.find('GAN') >= 0:
                module = import_module('loss.adversarial')
                loss_function = getattr(module, 'Adversarial')(
                    args,
                    loss_type
                )
           
            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function}
            )
            if loss_type.find('GAN') >= 0:
                self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})

        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for l in self.loss:
            if l['function'] is not None:
                print('{:.3f} * {}'.format(l['weight'], l['type']))
                self.loss_module.append(l['function'])

        self.log = torch.Tensor()

        device = torch.device('cpu' if args.cpu else 'cuda')
        self.loss_module.to(device)
        if args.precision == 'half': self.loss_module.half()
        if not args.cpu and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(
                self.loss_module, range(args.n_GPUs)
            )

        if args.load != '.': self.load(ckp.dir, cpu=args.cpu)

    def forward(self, sr, hr):
        losses = []
        for i, l in enumerate(self.loss):
            if l['function'] is not None:
                loss = l['function'](sr, hr)
                effective_loss = l['weight'] * loss
                losses.append(effective_loss)
                self.log[-1, i] += effective_loss.item()
            elif l['type'] == 'DIS':
                self.log[-1, i] += self.loss[i - 1]['function'].loss

        loss_sum = sum(losses)
        if len(self.loss) > 1:
            self.log[-1, -1] += loss_sum.item()

        return loss_sum

    def step(self):
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                l.scheduler.step()

    def start_log(self):
        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))

    def end_log(self, n_batches):
        self.log[-1].div_(n_batches)

    def display_loss(self, batch):
        n_samples = batch + 1
        log = []
        for l, c in zip(self.loss, self.log[-1]):
            log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))

        return ''.join(log)

    def plot_loss(self, apath, epoch):
        axis = np.linspace(1, epoch, epoch)
        for i, l in enumerate(self.loss):
            label = '{} Loss'.format(l['type'])
            fig = plt.figure()
            plt.title(label)
            plt.plot(axis, self.log[:, i].numpy(), label=label)
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.savefig('{}/loss_{}.pdf'.format(apath, l['type']))
            plt.close(fig)

    def get_loss_module(self):
        if self.n_GPUs == 1:
            return self.loss_module
        else:
            return self.loss_module.module

    def save(self, apath):
        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))

    def load(self, apath, cpu=False):
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        self.load_state_dict(torch.load(
            os.path.join(apath, 'loss.pt'),
            **kwargs
        ))
        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                for _ in range(len(self.log)): l.scheduler.step()

================================================
FILE: loss/adversarial.py
================================================
import utility
from model import common
from loss import discriminator

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

class Adversarial(nn.Module):
    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.discriminator = discriminator.Discriminator(args, gan_type)
        if gan_type != 'WGAN_GP':
            self.optimizer = utility.make_optimizer(args, self.discriminator)
        else:
            self.optimizer = optim.Adam(
                self.discriminator.parameters(),
                betas=(0, 0.9), eps=1e-8, lr=1e-5
            )
        self.scheduler = utility.make_scheduler(args, self.optimizer)

    def forward(self, fake, real):
        fake_detach = fake.detach()

        self.loss = 0
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            d_fake = self.discriminator(fake_detach)
            d_real = self.discriminator(real)
            if self.gan_type == 'GAN':
                label_fake = torch.zeros_like(d_fake)
                label_real = torch.ones_like(d_real)
                loss_d \
                    = F.binary_cross_entropy_with_logits(d_fake, label_fake) \
                    + F.binary_cross_entropy_with_logits(d_real, label_real)
            elif self.gan_type.find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    # One interpolation weight per sample, broadcast over C, H and W
                    epsilon = torch.rand(fake.size(0), 1, 1, 1, device=fake.device, dtype=fake.dtype)
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.discriminator(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty

            # Discriminator update
            self.loss += loss_d.item()
            loss_d.backward()
            self.optimizer.step()

            if self.gan_type == 'WGAN':
                for p in self.discriminator.parameters():
                    p.data.clamp_(-1, 1)

        self.loss /= self.gan_k

        d_fake_for_g = self.discriminator(fake)
        if self.gan_type == 'GAN':
            loss_g = F.binary_cross_entropy_with_logits(
                d_fake_for_g, label_real
            )
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_for_g.mean()

        # Generator loss
        return loss_g
    
    def state_dict(self, *args, **kwargs):
        state_discriminator = self.discriminator.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()

        return dict(**state_discriminator, **state_optimizer)
               
# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py


================================================
FILE: loss/discriminator.py
================================================
from model import common

import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, args, gan_type='GAN'):
        super(Discriminator, self).__init__()

        in_channels = 3
        out_channels = 64
        depth = 7
        #bn = not gan_type == 'WGAN_GP'
        bn = True
        act = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        m_features = [
            common.BasicBlock(args.n_colors, out_channels, 3, bn=bn, act=act)
        ]
        for i in range(depth):
            in_channels = out_channels
            if i % 2 == 1:
                stride = 1
                out_channels *= 2
            else:
                stride = 2
            m_features.append(common.BasicBlock(
                in_channels, out_channels, 3, stride=stride, bn=bn, act=act
            ))

        self.features = nn.Sequential(*m_features)

        patch_size = args.patch_size // (2**((depth + 1) // 2))
        m_classifier = [
            nn.Linear(out_channels * patch_size**2, 1024),
            act,
            nn.Linear(1024, 1)
        ]
        self.classifier = nn.Sequential(*m_classifier)

    def forward(self, x):
        features = self.features(x)
        output = self.classifier(features.view(features.size(0), -1))

        return output



================================================
FILE: loss/vgg.py
================================================
from model import common

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Variable

class VGG(nn.Module):
    def __init__(self, conv_index, rgb_range=1):
        super(VGG, self).__init__()
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m for m in vgg_features]
        if conv_index == '22':
            self.vgg = nn.Sequential(*modules[:8])
        elif conv_index == '54':
            self.vgg = nn.Sequential(*modules[:35])

        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
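        # Freeze the VGG weights; assigning requires_grad on the Module object itself has no effect
        for p in self.vgg.parameters():
            p.requires_grad = False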

    def forward(self, sr, hr):
        def _forward(x):
            x = self.sub_mean(x)
            x = self.vgg(x)
            return x
            
        vgg_sr = _forward(sr)
        with torch.no_grad():
            vgg_hr = _forward(hr.detach())

        loss = F.mse_loss(vgg_sr, vgg_hr)

        return loss


================================================
FILE: main.py
================================================
from option import args
import torch
import utility
import data
import model
import loss
from trainer import Trainer


if __name__ == '__main__':
    torch.manual_seed(args.seed)
    checkpoint = utility.checkpoint(args)
    if checkpoint.ok:
        loader = data.Data(args)
        model = model.Model(args, checkpoint)
        loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, loader, model, loss, checkpoint)
        while not t.terminate():
            t.train()

        checkpoint.done()


================================================
FILE: main.sh
================================================
# noise-free degradations with isotropic Gaussian blurs
python main.py --dir_data='D:/LongguangWang/Data' \
               --model='blindsr' \
               --scale='2' \
               --blur_type='iso_gaussian' \
               --noise=0.0 \
               --sig_min=0.2 \
               --sig_max=4.0


# general degradations with anisotropic Gaussian blurs and noises
python main.py --dir_data='D:/LongguangWang/Data' \
               --model='blindsr' \
               --scale='4' \
               --blur_type='aniso_gaussian' \
               --noise=25.0 \
               --lambda_min=0.2 \
               --lambda_max=4.0

cmd /k


================================================
FILE: moco/__init__.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


================================================
FILE: moco/builder.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import torch.nn as nn


class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722
    """
    def __init__(self, base_encoder, dim=256, K=32*256, m=0.999, T=0.07, mlp=False):
        """
        dim: feature dimension (default: 128)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        """
        super(MoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # base_encoder is instantiated without arguments, so its output dimension must match `dim`
        self.encoder_q = base_encoder()
        self.encoder_k = base_encoder()

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # updated by momentum, not by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        # keys = concat_all_gather(keys)
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1)
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            training: embedding, logits, labels
            eval: embedding
        """
        if self.training:
            # compute query features
            embedding, q = self.encoder_q(im_q)  # queries: NxC
            q = nn.functional.normalize(q, dim=1)

            # compute key features
            with torch.no_grad():  # no gradient to keys
                self._momentum_update_key_encoder()  # update the key encoder

                _, k = self.encoder_k(im_k)  # keys: NxC
                k = nn.functional.normalize(k, dim=1)

            # compute logits
            # Einstein sum is more intuitive
            # positive logits: Nx1
            l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
            # negative logits: NxK
            l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

            # logits: Nx(1+K)
            logits = torch.cat([l_pos, l_neg], dim=1)

            # apply temperature
            logits /= self.T

            # labels: the positive key sits at index 0 of every row
            labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

            # dequeue and enqueue
            self._dequeue_and_enqueue(k)

            return embedding, logits, labels
        else:
            embedding, _ = self.encoder_q(im_q)

            return embedding


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
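The queue in `MoCo` is a ring buffer of `K` keys. Each row of `logits` puts the positive key first, which is why `labels` is all zeros. The pure-Python sketch below mirrors those two pieces on plain lists. `RingQueue` and `info_nce` are hypothetical names. The vectors are assumed to be L2-normalized already, as `nn.functional.normalize` ensures in `forward`.

```python
import math


class RingQueue:
    """Fixed-size FIFO of K keys, overwritten in place like MoCo's `queue` buffer."""

    def __init__(self, K, init):
        assert len(init) == K
        self.K = K
        self.items = list(init)
        self.ptr = 0

    def dequeue_and_enqueue(self, keys):
        batch_size = len(keys)
        assert self.K % batch_size == 0  # same simplification as _dequeue_and_enqueue
        # the oldest batch_size keys are replaced by the newest ones
        self.items[self.ptr:self.ptr + batch_size] = keys
        self.ptr = (self.ptr + batch_size) % self.K  # move pointer, wrapping around


def info_nce(q, k_pos, negatives, T=0.07):
    """Cross-entropy of [positive, negatives...] logits with target index 0."""
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    logits = [dot(q, k_pos) / T] + [dot(q, n) / T for n in negatives]
    m = max(logits)  # subtract the max for a numerically stable log-sum-exp
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_sum_exp - logits[0]
```

Cross-entropy with target 0 reduces to `log(sum(exp(logits))) - logits[0]`, which is the value `info_nce` returns.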


================================================
FILE: model/__init__.py
================================================
import os
from importlib import import_module

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, args, ckp):
        super(Model, self).__init__()
        print('Making model...')
        self.args = args
        self.scale = args.scale
        self.idx_scale = 0
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        self.device = torch.device('cpu' if args.cpu else 'cuda')
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models
        self.save_dir = args.save  # not self.save, which would shadow the save() method

        module = import_module('model.'+args.model)
        self.model = module.make_model(args).to(self.device)
        if args.precision == 'half': self.model.half()

        if not args.cpu and args.n_GPUs > 1:
            self.model = nn.DataParallel(self.model, range(args.n_GPUs))

        self.load(
            ckp.dir,
            pre_train=args.pre_train,
            resume=args.resume,
            cpu=args.cpu
        )

    def forward(self, x):
        if self.self_ensemble and not self.training:
            if self.chop:
                forward_function = self.forward_chop
            else:
                forward_function = self.model.forward

            return self.forward_x8(x, forward_function)
        elif self.chop and not self.training:
            return self.forward_chop(x)
        else:
            return self.model(x)

    def get_model(self):
        if self.n_GPUs <= 1 or self.cpu:
            return self.model
        else:
            return self.model.module

    def state_dict(self, **kwargs):
        target = self.get_model()
        return target.state_dict(**kwargs)

    def save(self, apath, epoch, is_best=False):
        target = self.get_model()
        torch.save(
            target.state_dict(),
            os.path.join(apath, 'model', 'model_latest.pt')
        )
        if is_best:
            torch.save(
                target.state_dict(),
                os.path.join(apath, 'model', 'model_best.pt')
            )

        if self.save_models:
            torch.save(
                target.state_dict(),
                os.path.join(apath, 'model', 'model_{}.pt'.format(epoch))
            )

    def load(self, apath, pre_train='.', resume=-1, cpu=False):
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        if resume == -1:
            self.get_model().load_state_dict(
                torch.load(os.path.join(apath, 'model', 'model_latest.pt'), **kwargs),
                strict=True
            )

        elif resume == 0:
            if pre_train != '.':
                self.get_model().load_state_dict(
                    torch.load(pre_train, **kwargs),
                    strict=True
                )

        elif resume > 0:
            self.get_model().load_state_dict(
                torch.load(os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), **kwargs),
                strict=False
            )

    def forward_chop(self, x, shave=10, min_size=160000):
        scale = int(self.scale[self.idx_scale])  # scales are parsed as floats; indices must be ints
        n_GPUs = min(self.n_GPUs, 4)
        b, c, h, w = x.size()
        h_half, w_half = h // 2, w // 2
        h_size, w_size = h_half + shave, w_half + shave
        lr_list = [
            x[:, :, 0:h_size, 0:w_size],
            x[:, :, 0:h_size, (w - w_size):w],
            x[:, :, (h - h_size):h, 0:w_size],
            x[:, :, (h - h_size):h, (w - w_size):w]]

        if w_size * h_size < min_size:
            sr_list = []
            for i in range(0, 4, n_GPUs):
                lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
                sr_batch = self.model(lr_batch)
                sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
        else:
            sr_list = [
                self.forward_chop(patch, shave=shave, min_size=min_size) \
                for patch in lr_list
            ]

        h, w = scale * h, scale * w
        h_half, w_half = scale * h_half, scale * w_half
        h_size, w_size = scale * h_size, scale * w_size
        shave *= scale

        output = x.new(b, c, h, w)
        output[:, :, 0:h_half, 0:w_half] \
            = sr_list[0][:, :, 0:h_half, 0:w_half]
        output[:, :, 0:h_half, w_half:w] \
            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
        output[:, :, h_half:h, 0:w_half] \
            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
        output[:, :, h_half:h, w_half:w] \
            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

        return output

    def forward_x8(self, x, forward_function):
        def _transform(v, op):
            if self.precision != 'single': v = v.float()

            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            if self.precision == 'half': ret = ret.half()

            return ret

        lr_list = [x]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])

        sr_list = [forward_function(aug) for aug in lr_list]
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        output = output_cat.mean(dim=0, keepdim=True)

        return output
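`forward_chop` splits the input into four overlapping quadrants, each extended by `shave` pixels of context. It super-resolves the quadrants separately and keeps only the non-overlapping part of each output. The 1-D sketch below reproduces the stitching indices of a single split. Nearest-neighbour upsampling stands in for the network, and `chop_1d` and `upsample_nearest` are illustrative names. Because the indices are multiplied by `scale`, the scale must be an integer, whereas `option.py` stores scales as floats. The sketch also assumes the input is long enough that `w // 2 + shave <= w`.

```python
def upsample_nearest(seq, scale):
    """Stand-in for the SR network: repeat every sample `scale` times."""
    return [v for v in seq for _ in range(scale)]


def chop_1d(x, model, scale, shave=10):
    """1-D analogue of Model.forward_chop with a single split (no recursion)."""
    w = len(x)
    w_half = w // 2
    w_size = w_half + shave          # each half carries `shave` extra samples of context
    left = model(x[0:w_size], scale)
    right = model(x[w - w_size:w], scale)

    # move every index into output (HR) coordinates; scale must be an int here
    w, w_half, w_size = scale * w, scale * w_half, scale * w_size
    # keep the first half of `left` and the matching tail of `right`, dropping the overlap
    out = left[0:w_half] + right[w_size - w + w_half:w_size]
    assert len(out) == w
    return out
```

With a point-wise model such as nearest-neighbour upsampling, the stitched output matches the unchopped one exactly. For a real CNN, the `shave` margin only approximates the full receptive field, so pixels near the seams can differ slightly.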



================================================
FILE: model/blindsr.py
================================================
import torch
from torch import nn
import model.common as common
import torch.nn.functional as F
from moco.builder import MoCo


def make_model(args):
    return BlindSR(args)


class DA_conv(nn.Module):
    def __init__(self, channels_in, channels_out, kernel_size, reduction):
        super(DA_conv, self).__init__()
        self.channels_out = channels_out
        self.channels_in = channels_in
        self.kernel_size = kernel_size

        self.kernel = nn.Sequential(
            nn.Linear(64, 64, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Linear(64, 64 * self.kernel_size * self.kernel_size, bias=False)
        )
        self.conv = common.default_conv(channels_in, channels_out, 1)
        self.ca = CA_layer(channels_in, channels_out, reduction)

        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        '''
        b, c, h, w = x[0].size()

        # branch 1
        kernel = self.kernel(x[1]).view(-1, 1, self.kernel_size, self.kernel_size)
        out = self.relu(F.conv2d(x[0].view(1, -1, h, w), kernel, groups=b*c, padding=(self.kernel_size-1)//2))
        out = self.conv(out.view(b, -1, h, w))

        # branch 2
        out = out + self.ca(x)

        return out


class CA_layer(nn.Module):
    def __init__(self, channels_in, channels_out, reduction):
        super(CA_layer, self).__init__()
        self.conv_du = nn.Sequential(
            nn.Conv2d(channels_in, channels_in//reduction, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_in // reduction, channels_out, 1, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        '''
        att = self.conv_du(x[1][:, :, None, None])

        return x[0] * att


class DAB(nn.Module):
    def __init__(self, conv, n_feat, kernel_size, reduction):
        super(DAB, self).__init__()

        self.da_conv1 = DA_conv(n_feat, n_feat, kernel_size, reduction)
        self.da_conv2 = DA_conv(n_feat, n_feat, kernel_size, reduction)
        self.conv1 = conv(n_feat, n_feat, kernel_size)
        self.conv2 = conv(n_feat, n_feat, kernel_size)

        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        '''

        out = self.relu(self.da_conv1(x))
        out = self.relu(self.conv1(out))
        out = self.relu(self.da_conv2([out, x[1]]))
        out = self.conv2(out) + x[0]

        return out


class DAG(nn.Module):
    def __init__(self, conv, n_feat, kernel_size, reduction, n_blocks):
        super(DAG, self).__init__()
        self.n_blocks = n_blocks
        modules_body = [
            DAB(conv, n_feat, kernel_size, reduction) \
            for _ in range(n_blocks)
        ]
        modules_body.append(conv(n_feat, n_feat, kernel_size))

        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        '''
        res = x[0]
        for i in range(self.n_blocks):
            res = self.body[i]([res, x[1]])
        res = self.body[-1](res)
        res = res + x[0]

        return res


class DASR(nn.Module):
    def __init__(self, args, conv=common.default_conv):
        super(DASR, self).__init__()

        self.n_groups = 5
        n_blocks = 5
        n_feats = 64
        kernel_size = 3
        reduction = 8
        scale = int(args.scale[0])

        # RGB mean for DIV2K
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(255.0, rgb_mean, rgb_std)
        self.add_mean = common.MeanShift(255.0, rgb_mean, rgb_std, 1)

        # head module
        modules_head = [conv(3, n_feats, kernel_size)]
        self.head = nn.Sequential(*modules_head)

        # compress
        self.compress = nn.Sequential(
            nn.Linear(256, 64, bias=False),
            nn.LeakyReLU(0.1, True)
        )

        # body
        modules_body = [
            DAG(common.default_conv, n_feats, kernel_size, reduction, n_blocks) \
            for _ in range(self.n_groups)
        ]
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*modules_body)

        # tail
        modules_tail = [common.Upsampler(conv, scale, n_feats, act=False),
                        conv(n_feats, 3, kernel_size)]
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x, k_v):
        k_v = self.compress(k_v)

        # sub mean
        x = self.sub_mean(x)

        # head
        x = self.head(x)

        # body
        res = x
        for i in range(self.n_groups):
            res = self.body[i]([res, k_v])
        res = self.body[-1](res)
        res = res + x

        # tail
        x = self.tail(res)

        # add mean
        x = self.add_mean(x)

        return x


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()

        self.E = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1, True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.mlp = nn.Sequential(
            nn.Linear(256, 256),
            nn.LeakyReLU(0.1, True),
            nn.Linear(256, 256),
        )

    def forward(self, x):
        fea = self.E(x).squeeze(-1).squeeze(-1)
        out = self.mlp(fea)

        return fea, out


class BlindSR(nn.Module):
    def __init__(self, args):
        super(BlindSR, self).__init__()

        # Generator
        self.G = DASR(args)

        # Encoder
        self.E = MoCo(base_encoder=Encoder)

    def forward(self, x):
        if self.training:
            x_query = x[:, 0, ...]                          # b, c, h, w
            x_key = x[:, 1, ...]                            # b, c, h, w

            # degradation-aware representation learning
            fea, logits, labels = self.E(x_query, x_key)

            # degradation-aware SR
            sr = self.G(x_query, fea)

            return sr, logits, labels
        else:
            # degradation-aware representation learning
            fea = self.E(x, x)

            # degradation-aware SR
            sr = self.G(x, fea)

            return sr
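`DA_conv` predicts one `k x k` kernel per channel per sample and applies all of them in a single grouped convolution. Folding the batch into channels (`x[0].view(1, -1, h, w)`) and setting `groups=b*c` pairs group `bi * c + ci` with row `bi * c + ci` of `kernel.view(-1, 1, k, k)`. The 1-D, pure-Python sketch below uses hypothetical helper names. It checks that this ordering gives the same result as looping over samples and channels. The real module hard-codes 64 in `self.kernel`, so there `c` must be 64.

```python
def conv1d_same(signal, kernel):
    """Zero-padded 1-D cross-correlation (what F.conv2d computes), odd kernel length."""
    pad = (len(kernel) - 1) // 2
    padded = [0.0] * pad + list(signal) + [0.0] * pad
    return [sum(kernel[j] * padded[i + j] for j in range(len(kernel)))
            for i in range(len(signal))]


def grouped_depthwise(x, kernel_flat, k):
    """Emulate F.conv2d(x.view(1, b*c, L), kernel.view(-1, 1, k), groups=b*c) in 1-D.

    x: b lists of c signals; kernel_flat: b lists of c*k weights (one MLP output per sample).
    """
    b, c = len(x), len(x[0])
    channels = [x[bi][ci] for bi in range(b) for ci in range(c)]       # view(1, b*c, L)
    kernels = [kernel_flat[bi][g * k:(g + 1) * k]                        # view(-1, 1, k)
               for bi in range(b) for g in range(c)]
    out = [conv1d_same(channels[g], kernels[g]) for g in range(b * c)]  # groups=b*c
    return [out[bi * c:(bi + 1) * c] for bi in range(b)]                # view(b, c, L)
```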


================================================
FILE: model/common.py
================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size//2), bias=bias)


class MeanShift(nn.Conv2d):
    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        self.weight.requires_grad = False
        self.bias.requires_grad = False


class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feat, act=False, bias=True):
        m = []
        if (scale & (scale - 1)) == 0:    # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feat, 4 * n_feat, 3, bias))
                m.append(nn.PixelShuffle(2))
                if act: m.append(act())
        elif scale == 3:
            m.append(conv(n_feat, 9 * n_feat, 3, bias))
            m.append(nn.PixelShuffle(3))
            if act: m.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)
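For power-of-two scales, `Upsampler` chains `log2(scale)` blocks. Each block is a 3x3 convolution to `4 * n_feat` channels followed by `nn.PixelShuffle(2)`. For scale 3 it uses a single block with `9 * n_feat` channels and `nn.PixelShuffle(3)`. The sketch below is a pure-Python rendition of the pixel-shuffle rearrangement for one image, following the index mapping documented for PyTorch's `PixelShuffle`, together with the stage count. The helper names are illustrative. It computes the integer log2 with `bit_length` rather than `math.log` to avoid floating-point rounding.

```python
def pixel_shuffle(x, r):
    """nn.PixelShuffle(r) for a single image given as nested lists [C*r*r][H][W].

    out[c][h*r + i][w*r + j] = x[c*r*r + i*r + j][h][w]
    """
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    assert c_in % (r * r) == 0
    c_out = c_in // (r * r)
    out = [[[0] * (w * r) for _ in range(h * r)] for _ in range(c_out)]
    for c in range(c_out):
        for i in range(r):
            for j in range(r):
                src = x[c * r * r + i * r + j]
                for hh in range(h):
                    for ww in range(w):
                        out[c][hh * r + i][ww * r + j] = src[hh][ww]
    return out


def upsampler_stages(scale):
    """Mirror Upsampler's branching: (number of conv + PixelShuffle stages, factor per stage)."""
    if scale >= 1 and (scale & (scale - 1)) == 0:  # power of two
        return scale.bit_length() - 1, 2          # integer log2
    if scale == 3:
        return 1, 3
    raise NotImplementedError(scale)
```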



================================================
FILE: option.py
================================================
import argparse
import template

parser = argparse.ArgumentParser(description='EDSR and MDSR')

parser.add_argument('--debug', action='store_true',
                    help='Enables debug mode')
parser.add_argument('--template', default='.',
                    help='You can set various templates in template.py')

# Hardware specifications
parser.add_argument('--n_threads', type=int, default=4,
                    help='number of threads for data loading')
parser.add_argument('--cpu', type=lambda s: s.lower() in ('true', '1', 'yes'), default=False,
                    help='use cpu only (true | false)')
parser.add_argument('--n_GPUs', type=int, default=2,
                    help='number of GPUs')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')

# Data specifications
parser.add_argument('--dir_data', type=str, default='D:/LongguangWang/Data',
                    help='dataset directory')
parser.add_argument('--dir_demo', type=str, default='../test',
                    help='demo image directory')
parser.add_argument('--data_train', type=str, default='DF2K',
                    help='train dataset name')
parser.add_argument('--data_test', type=str, default='Set14',
                    help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-3450/801-810',
                    help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
                    help='dataset file extension')
parser.add_argument('--scale', type=str, default='4',
                    help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=48,
                    help='output patch size')
parser.add_argument('--rgb_range', type=int, default=255,
                    help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
                    help='number of color channels to use')
parser.add_argument('--chop', action='store_true',
                    help='enable memory-efficient forward')
parser.add_argument('--no_augment', action='store_true',
                    help='do not use data augmentation')

# Degradation specifications
parser.add_argument('--blur_kernel', type=int, default=21,
                    help='size of blur kernels')
parser.add_argument('--blur_type', type=str, default='iso_gaussian',
                    help='blur types (iso_gaussian | aniso_gaussian)')
parser.add_argument('--mode', type=str, default='bicubic',
                    help='downsampler (bicubic | s-fold)')
parser.add_argument('--noise', type=float, default=0.0,
                    help='noise level')
## isotropic Gaussian blur
parser.add_argument('--sig_min', type=float, default=0.2,
                    help='minimum sigma of isotropic Gaussian blurs')
parser.add_argument('--sig_max', type=float, default=4.0,
                    help='maximum sigma of isotropic Gaussian blurs')
parser.add_argument('--sig', type=float, default=4.0,
                    help='specific sigma of isotropic Gaussian blurs')
## anisotropic Gaussian blur
parser.add_argument('--lambda_min', type=float, default=0.2,
                    help='minimum value for the eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--lambda_max', type=float, default=4.0,
                    help='maximum value for the eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--lambda_1', type=float, default=0.2,
                    help='one eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--lambda_2', type=float, default=4.0,
                    help='another eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--theta', type=float, default=0.0,
                    help='rotation angle of anisotropic Gaussian blurs [0, 180]')


# Model specifications
parser.add_argument('--model', default='blindsr',
                    help='model name')
parser.add_argument('--pre_train', type=str, default= '.',
                    help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.',
                    help='pre-trained model directory')
parser.add_argument('--shift_mean', default=True,
                    help='subtract pixel mean from the input')
parser.add_argument('--dilation', action='store_true',
                    help='use dilated convolution')
parser.add_argument('--precision', type=str, default='single',
                    choices=('single', 'half'),
                    help='FP precision for test (single | half)')

# Training specifications
parser.add_argument('--reset', action='store_true',
                    help='reset the training')
parser.add_argument('--test_every', type=int, default=1000,
                    help='do test per every N batches')
parser.add_argument('--epochs_encoder', type=int, default=100,
                    help='number of epochs to train the degradation encoder')
parser.add_argument('--epochs_sr', type=int, default=500,
                    help='number of epochs to train the whole network')
parser.add_argument('--batch_size', type=int, default=32,
                    help='input batch size for training')
parser.add_argument('--split_batch', type=int, default=1,
                    help='split the batch into smaller chunks')
parser.add_argument('--self_ensemble', action='store_true',
                    help='use self-ensemble method for test')
parser.add_argument('--test_only', action='store_true',
                    help='set this option to test the model')

# Optimization specifications
parser.add_argument('--lr_encoder', type=float, default=1e-3,
                    help='learning rate to train the degradation encoder')
parser.add_argument('--lr_sr', type=float, default=1e-4,
                    help='learning rate to train the whole network')
parser.add_argument('--lr_decay_encoder', type=int, default=60,
                    help='learning rate decay per N epochs')
parser.add_argument('--lr_decay_sr', type=int, default=125,
                    help='learning rate decay per N epochs')
parser.add_argument('--decay_type', type=str, default='step',
                    help='learning rate decay type')
parser.add_argument('--gamma_encoder', type=float, default=0.1,
                    help='learning rate decay factor for step decay')
parser.add_argument('--gamma_sr', type=float, default=0.5,
                    help='learning rate decay factor for step decay')
parser.add_argument('--optimizer', default='ADAM',
                    choices=('SGD', 'ADAM', 'RMSprop'),
                    help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum')
parser.add_argument('--beta1', type=float, default=0.9,
                    help='ADAM beta1')
parser.add_argument('--beta2', type=float, default=0.999,
                    help='ADAM beta2')
parser.add_argument('--epsilon', type=float, default=1e-8,
                    help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight decay')
parser.add_argument('--start_epoch', type=int, default=0,
                    help='resume from the snapshot, and the start_epoch')

# Loss specifications
parser.add_argument('--loss', type=str, default='1*L1',
                    help='loss function configuration')
parser.add_argument('--skip_threshold', type=float, default='1e6',
                    help='skipping batch that has large error')

# Log specifications
parser.add_argument('--save', type=str, default='blindsr',
                    help='file name to save')
parser.add_argument('--load', type=str, default='.',
                    help='file name to load')
parser.add_argument('--resume', type=int, default=0,
                    help='resume from specific checkpoint')
parser.add_argument('--save_models', action='store_true',
                    help='save all intermediate models')
parser.add_argument('--print_every', type=int, default=200,
                    help='how many batches to wait before logging training status')
parser.add_argument('--save_results', default=False,
                    help='save output results')

args = parser.parse_args()
template.set_template(args)

args.scale = list(map(lambda x: float(x), args.scale.split('+')))
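Two parsing details matter for these options. First, argparse applies `type` to the raw string, so `type=bool` turns any non-empty value, including `False`, into `True`. Second, `--scale` accepts several scales joined by `+`, and the last line above converts them to floats. A short, self-contained demonstration follows. `str2bool` and `--cpu_naive` are hypothetical names used only for illustration.

```python
import argparse


def str2bool(s):
    """Parse common true/false spellings explicitly instead of calling bool() on the string."""
    if s.lower() in ("true", "1", "yes"):
        return True
    if s.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {s!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--cpu_naive", type=bool, default=False)  # the pitfall: bool("False") is True
parser.add_argument("--cpu", type=str2bool, default=False)    # explicit parsing
parser.add_argument("--scale", type=str, default="4")

args = parser.parse_args(["--cpu_naive", "False", "--cpu", "False", "--scale", "2+3+4"])
args.scale = [float(s) for s in args.scale.split("+")]  # same conversion as option.py
print(args.cpu_naive, args.cpu, args.scale)  # True False [2.0, 3.0, 4.0]
```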



================================================
FILE: quick_test.py
================================================
from model.blindsr import BlindSR
import torch
import numpy as np
import imageio
import argparse
import os
import utility
import cv2


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--img_dir', type=str, default='D:/LongguangWang/Data/test.png',
                        help='path to the input LR image')
    parser.add_argument('--scale', type=str, default='2',
                        help='super resolution scale')
    parser.add_argument('--resume', type=int, default=600,
                        help='resume from specific checkpoint')
    parser.add_argument('--blur_type', type=str, default='iso_gaussian',
                        help='blur types (iso_gaussian | aniso_gaussian)')
    return parser.parse_args()


def main():
    args = parse_args()
    if args.blur_type == 'iso_gaussian':
        dir = './experiment/blindsr_x' + str(int(args.scale[0])) + '_bicubic_iso'
    elif args.blur_type == 'aniso_gaussian':
        dir = './experiment/blindsr_x' + str(int(args.scale[0])) + '_bicubic_aniso'
    else:
        raise ValueError('unknown blur_type: ' + args.blur_type)

    # path to save sr images
    save_dir = dir + '/results'
    os.makedirs(save_dir, exist_ok=True)

    DASR = BlindSR(args).cuda()
    DASR.load_state_dict(torch.load(dir + '/model/model_' + str(args.resume) + '.pt'), strict=False)
    DASR.eval()

    lr = imageio.imread(args.img_dir)
    lr = np.ascontiguousarray(lr.transpose((2, 0, 1)))
    lr = torch.from_numpy(lr).float().cuda().unsqueeze(0).unsqueeze(0)

    # inference
    sr = DASR(lr[:, 0, ...])
    sr = utility.quantize(sr, 255.0)

    # save sr results
    img_name = os.path.splitext(os.path.basename(args.img_dir))[0]
    sr = sr.squeeze(0).permute(1, 2, 0).cpu().numpy()
    sr = sr[:, :, [2, 1, 0]].astype(np.uint8)  # RGB -> BGR for OpenCV, 8-bit for PNG
    cv2.imwrite(save_dir + '/' + img_name + '_sr.png', sr)


if __name__ == '__main__':
    with torch.no_grad():
        main()
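
# Sketch (hypothetical helper, not part of the original script): how main() maps
# --scale, --blur_type and --resume to the checkpoint it loads. Here args.scale is
# a plain string, so args.scale[0] is only its first character; this mirrors that
# behaviour rather than fixing it. The name `checkpoint_path` is illustrative only.
def checkpoint_path(scale, blur_type, resume):
    suffixes = {'iso_gaussian': 'iso', 'aniso_gaussian': 'aniso'}
    if blur_type not in suffixes:
        raise ValueError('unsupported blur_type: {}'.format(blur_type))
    exp_dir = './experiment/blindsr_x' + str(int(scale[0])) + '_bicubic_' + suffixes[blur_type]
    return exp_dir + '/model/model_' + str(resume) + '.pt'

# Usage: checkpoint_path('4', 'aniso_gaussian', 600) returns
# './experiment/blindsr_x4_bicubic_aniso/model/model_600.pt', one of the
# pretrained checkpoints shipped in the experiment/ folder.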

================================================
FILE: quick_test.sh
================================================
# super-resolve an LR image (x2) using the model trained on noise-free degradations with isotropic Gaussian blurs
python quick_test.py --img_dir='D:/LongguangWang/Data/test.png' \
                     --scale='2' \
                     --resume=600 \
                     --blur_type='iso_gaussian'

# super-resolve an LR image (x4) using the model trained on general degradations with anisotropic Gaussian blurs and noises
python quick_test.py --img_dir='D:/LongguangWang/Data/test.png' \
                     --scale='4' \
                     --resume=600 \
                     --blur_type='aniso_gaussian'

cmd /k


================================================
FILE: template.py
================================================
def set_template(args):
    # Set the templates here
    if args.template.find('jpeg') >= 0:
        args.data_train = 'DIV2K_jpeg'
        args.data_test = 'DIV2K_jpeg'
        args.epochs = 200
        args.lr_decay = 100

    if args.template.find('EDSR_paper') >= 0:
        args.model = 'EDSR'
        args.n_resblocks = 32
        args.n_feats = 256
        args.res_scale = 0.1

    if args.template.find('MDSR') >= 0:
        args.model = 'MDSR'
        args.patch_size = 48
        args.epochs = 650

    if args.template.find('DDBPN') >= 0:
        args.model = 'DDBPN'
        args.patch_size = 128
        args.scale = '4'

        args.data_test = 'Set5'

        args.batch_size = 20
        args.epochs = 1000
        args.lr_decay = 500
        args.gamma = 0.1
        args.weight_decay = 1e-4

        args.loss = '1*MSE'

    if args.template.find('GAN') >= 0:
        args.epochs = 200
        args.lr = 5e-5
        args.lr_decay = 150

    if args.template.find('RCAN') >= 0:
        args.model = 'RCAN'
        args.n_resgroups = 10
        args.n_resblocks = 20
        args.n_feats = 64
        args.chop = True



================================================
FILE: test.py
================================================
from option import args
import torch
import utility
import data
import model
import loss
from trainer import Trainer


if __name__ == '__main__':
    torch.manual_seed(args.seed)
    checkpoint = utility.checkpoint(args)

    if checkpoint.ok:
        loader = data.Data(args)
        model = model.Model(args, checkpoint)
        loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, loader, model, loss, checkpoint)
        while not t.terminate():
            t.test()

        checkpoint.done()


================================================
FILE: test.sh
================================================
# noise-free degradations with isotropic Gaussian blurs
python test.py --test_only \
               --dir_data='D:/LongguangWang/Data' \
               --data_test='Set14' \
               --model='blindsr' \
               --scale='2' \
               --resume=600 \
               --blur_type='iso_gaussian' \
               --noise=0.0 \
               --sig=1.2


# general degradations with anisotropic Gaussian blurs and noises
python test.py --test_only \
               --dir_data='D:/LongguangWang/Data' \
               --data_test='Set14' \
               --model='blindsr' \
               --scale='4' \
               --resume=600 \
               --blur_type='aniso_gaussian' \
               --noise=10.0 \
               --theta=0.0 \
               --lambda_1=0.2 \
               --lambda_2=4.0

cmd /k

================================================
FILE: trainer.py
================================================
import os
import utility
import torch
from decimal import Decimal
import torch.nn.functional as F
from utils import util


class Trainer():
    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        self.model_E = torch.nn.DataParallel(self.model.get_model().E, range(self.args.n_GPUs))
        self.loss = my_loss
        self.contrast_loss = torch.nn.CrossEntropyLoss().cuda()
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if self.args.load != '.':
            self.optimizer.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
            )
            for _ in range(len(ckp.log)): self.scheduler.step()

    def train(self):
        self.scheduler.step()
        self.loss.step()
        epoch = self.scheduler.last_epoch + 1

        # lr stepwise
        if epoch <= self.args.epochs_encoder:
            lr = self.args.lr_encoder * (self.args.gamma_encoder ** (epoch // self.args.lr_decay_encoder))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
        else:
            lr = self.args.lr_sr * (self.args.gamma_sr ** ((epoch - self.args.epochs_encoder) // self.args.lr_decay_sr))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

        self.ckp.write_log('[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)))
        self.loss.start_log()
        self.model.train()

        degrade = util.SRMDPreprocessing(
            self.scale[0],
            kernel_size=self.args.blur_kernel,
            blur_type=self.args.blur_type,
            sig_min=self.args.sig_min,
            sig_max=self.args.sig_max,
            lambda_min=self.args.lambda_min,
            lambda_max=self.args.lambda_max,
            noise=self.args.noise
        )

        timer = utility.timer()
        losses_contrast, losses_sr = utility.AverageMeter(), utility.AverageMeter()

        for batch, (hr, _, idx_scale) in enumerate(self.loader_train):
            hr = hr.cuda()                              # b, n, c, h, w
            lr, b_kernels = degrade(hr)                 # bn, c, h, w

            self.optimizer.zero_grad()

            timer.tic()
            # forward
            ## train degradation encoder
            if epoch <= self.args.epochs_encoder:
                _, output, target = self.model_E(im_q=lr[:,0,...], im_k=lr[:,1,...])
                loss_contrast = self.contrast_loss(output, target)
                loss = loss_contrast

                losses_contrast.update(loss_contrast.item())
            ## train the whole network
            else:
                sr, output, target = self.model(lr)
                loss_SR = self.loss(sr, hr[:,0,...])
                loss_contrast = self.contrast_loss(output, target)
                loss = loss_contrast + loss_SR

                losses_sr.update(loss_SR.item())
                losses_contrast.update(loss_contrast.item())

            # backward
            loss.backward()
            self.optimizer.step()
            timer.hold()

            if epoch <= self.args.epochs_encoder:
                if (batch + 1) % self.args.print_every == 0:
                    self.ckp.write_log(
                        'Epoch: [{:03d}][{:04d}/{:04d}]\t'
                        'Loss [contrastive loss: {:.3f}]\t'
                        'Time [{:.1f}s]'.format(
                            epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset),
                            losses_contrast.avg,
                            timer.release()
                        ))
            else:
                if (batch + 1) % self.args.print_every == 0:
                    self.ckp.write_log(
                        'Epoch: [{:04d}][{:04d}/{:04d}]\t'
                        'Loss [SR loss:{:.3f} | contrastive loss: {:.3f}]\t'
                        'Time [{:.1f}s]'.format(
                            epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset),
                            losses_sr.avg, losses_contrast.avg,
                            timer.release(),
                        ))

        self.loss.end_log(len(self.loader_train))

        # save model
        target = self.model.get_model()
        model_dict = target.state_dict()
        keys = list(model_dict.keys())
        for key in keys:
            if 'E.encoder_k' in key or 'queue' in key:
                del model_dict[key]
        torch.save(
            model_dict,
            os.path.join(self.ckp.dir, 'model', 'model_{}.pt'.format(epoch))
        )

    def test(self):
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(torch.zeros(1, len(self.scale)))
        self.model.eval()

        timer_test = utility.timer()

        with torch.no_grad():
            for idx_scale, scale in enumerate(self.scale):
                self.loader_test.dataset.set_scale(idx_scale)
                eval_psnr = 0
                eval_ssim = 0

                degrade = util.SRMDPreprocessing(
                    scale,
                    kernel_size=self.args.blur_kernel,
                    blur_type=self.args.blur_type,
                    sig=self.args.sig,
                    lambda_1=self.args.lambda_1,
                    lambda_2=self.args.lambda_2,
                    theta=self.args.theta,
                    noise=self.args.noise
                )

                for idx_img, (hr, filename, _) in enumerate(self.loader_test):
                    hr = hr.cuda()                      # b, 1, c, h, w
                    hr = self.crop_border(hr, scale)
                    lr, _ = degrade(hr, random=False)   # b, 1, c, h, w
                    hr = hr[:, 0, ...]                  # b, c, h, w

                    # inference
                    timer_test.tic()
                    sr = self.model(lr[:, 0, ...])
                    timer_test.hold()

                    sr = utility.quantize(sr, self.args.rgb_range)
                    hr = utility.quantize(hr, self.args.rgb_range)

                    # metrics
                    eval_psnr += utility.calc_psnr(
                        sr, hr, scale, self.args.rgb_range,
                        benchmark=self.loader_test.dataset.benchmark
                    )
                    eval_ssim += utility.calc_ssim(
                        sr, hr, scale,
                        benchmark=self.loader_test.dataset.benchmark
                    )

                    # save results
                    if self.args.save_results:
                        save_list = [sr]
                        filename = filename[0]
                        self.ckp.save_results(filename, save_list, scale)

                self.ckp.log[-1, idx_scale] = eval_psnr / len(self.loader_test)
                self.ckp.write_log(
                    '[Epoch {}---{} x{}]\tPSNR: {:.3f} SSIM: {:.4f}'.format(
                        self.args.resume,
                        self.args.data_test,
                        scale,
                        eval_psnr / len(self.loader_test),
                        eval_ssim / len(self.loader_test),
                    ))

    def crop_border(self, img_hr, scale):
        b, n, c, h, w = img_hr.size()

        img_hr = img_hr[:, :, :, :int(h//scale*scale), :int(w//scale*scale)]

        return img_hr

    def terminate(self):
        if self.args.test_only:
            self.test()
            return True
        else:
            epoch = self.scheduler.last_epoch + 1
            return epoch >= self.args.epochs_encoder + self.args.epochs_sr
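
# Sketch (hypothetical helper, not part of the original trainer.py): the stepwise
# learning rate that Trainer.train() assigns to every parameter group. For the first
# epochs_encoder epochs only the degradation encoder is trained; after that the whole
# network is trained and the SR schedule counts epochs from epochs_encoder onwards.
# The function name and the values in the usage note are illustrative only.
def two_phase_lr(epoch, epochs_encoder, lr_encoder, gamma_encoder, lr_decay_encoder,
                 lr_sr, gamma_sr, lr_decay_sr):
    if epoch <= epochs_encoder:
        # Phase 1: contrastive training of the degradation encoder
        return lr_encoder * gamma_encoder ** (epoch // lr_decay_encoder)
    # Phase 2: joint training, decayed every lr_decay_sr epochs after phase 1 ends
    return lr_sr * gamma_sr ** ((epoch - epochs_encoder) // lr_decay_sr)

# Usage (illustrative values, not the repository defaults): with epochs_encoder=100,
# lr_encoder=1e-3, gamma_encoder=0.1, lr_decay_encoder=60, lr_sr=1e-4, gamma_sr=0.5
# and lr_decay_sr=125, epoch 60 uses about 1e-4 and epoch 225 uses 5e-5.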



================================================
FILE: utility.py
================================================
import os
import math
import time
import datetime
import matplotlib.pyplot as plt
import numpy as np
import scipy.misc as misc
import cv2
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class timer():
    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        self.t0 = time.time()

    def toc(self):
        return time.time() - self.t0

    def hold(self):
        self.acc += self.toc()

    def release(self):
        ret = self.acc
        self.acc = 0

        return ret

    def reset(self):
        self.acc = 0


class checkpoint():
    def __init__(self, args):
        self.args = args
        self.ok = True
        self.log = torch.Tensor()
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        if args.blur_type == 'iso_gaussian':
            self.dir = './experiment/' + args.save + '_x' + str(int(args.scale[0])) + '_' + args.mode + '_iso'
        elif args.blur_type == 'aniso_gaussian':
            self.dir = './experiment/' + args.save + '_x' + str(int(args.scale[0])) + '_' + args.mode + '_aniso'

        def _make_dir(path):
            if not os.path.exists(path): os.makedirs(path)

        _make_dir(self.dir)
        _make_dir(self.dir + '/model')
        _make_dir(self.dir + '/results')

        open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
        self.log_file = open(self.dir + '/log.txt', open_type)
        with open(self.dir + '/config.txt', open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')

    def save(self, trainer, epoch, is_best=False):
        trainer.model.save(self.dir, epoch, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
        torch.save(
            trainer.optimizer.state_dict(),
            os.path.join(self.dir, 'optimizer.pt')
        )

    def add_log(self, log):
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        print(log)
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.dir + '/log.txt', 'a')

    def done(self):
        self.log_file.close()

    def plot_psnr(self, epoch):
        axis = np.linspace(1, epoch, epoch)
        label = 'SR on {}'.format(self.args.data_test)
        fig = plt.figure()
        plt.title(label)
        for idx_scale, scale in enumerate(self.args.scale):
            plt.plot(
                axis,
                self.log[:, idx_scale].numpy(),
                label='Scale {}'.format(scale)
            )
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel('PSNR')
        plt.grid(True)
        plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
        plt.close(fig)

    def save_results(self, filename, save_list, scale):
        filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)

        normalized = save_list[0][0].data.mul(255 / self.args.rgb_range)
        ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
        cv2.imwrite('{}{}.png'.format(filename, 'SR'), np.ascontiguousarray(ndarr[:, :, ::-1]))  # RGB -> BGR for OpenCV


def quantize(img, rgb_range):
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
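
# Sketch (hypothetical scalar version, for illustration only): quantize() rescales a
# value from [0, rgb_range] to [0, 255], clamps it, rounds it to the nearest 8-bit
# level and maps it back, so the output sits on the same grid as a saved PNG.
# Python's round() and torch.round() both round halves to even.
def quantize_scalar(value, rgb_range):
    pixel_range = 255 / rgb_range
    level = round(min(max(value * pixel_range, 0.0), 255.0))
    return level / pixel_range

# Usage: quantize_scalar(127.6, 255) returns 128.0 and quantize_scalar(300, 255)
# returns 255.0.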


def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    diff = (sr - hr).data.div(rgb_range)
    if benchmark:
        shave = scale
        if diff.size(1) > 1:
            convert = diff.new(1, 3, 1, 1)
            convert[0, 0, 0, 0] = 65.738
            convert[0, 1, 0, 0] = 129.057
            convert[0, 2, 0, 0] = 25.064
            diff.mul_(convert).div_(256)
            diff = diff.sum(dim=1, keepdim=True)
    else:
        shave = scale + 6
    shave = math.ceil(shave)
    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = valid.pow(2).mean()

    return -10 * math.log10(mse)
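
import math

# Sketch (hypothetical pure-Python helper, single channel only): the quantity that
# calc_psnr() computes after shaving the border. Differences are divided by
# rgb_range, so the peak signal is 1 and PSNR = -10 * log10(MSE). calc_psnr shaves
# ceil(scale) pixels on benchmark sets (after converting RGB to Y) and
# ceil(scale) + 6 otherwise; the Y conversion is left out here. As in calc_psnr,
# identical images give MSE = 0 and math.log10 raises ValueError.
def psnr_shaved(sr, hr, shave, rgb_range=255):
    h, w = len(hr), len(hr[0])
    sq_errs = [((sr[i][j] - hr[i][j]) / rgb_range) ** 2
               for i in range(shave, h - shave)
               for j in range(shave, w - shave)]
    mse = sum(sq_errs) / len(sq_errs)
    return -10 * math.log10(mse)

# Usage: if every pixel of a 255-range image is off by 5, the result is
# 20 * log10(51), roughly 34.15 dB.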


def calc_ssim(img1, img2, scale=2, benchmark=False):
    '''calculate SSIM on the Y channel, intended to match MATLAB's output
    img1, img2: [0, 255]
    '''
    if benchmark:
        border = math.ceil(scale)
    else:
        border = math.ceil(scale) + 6

    img1 = img1.data.squeeze().float().clamp(0, 255).round().cpu().numpy()
    img1 = np.transpose(img1, (1, 2, 0))
    img2 = img2.data.squeeze().cpu().numpy()
    img2 = np.transpose(img2, (1, 2, 0))

    img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 255.0 + 16.0
    img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 255.0 + 16.0
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1_y = img1_y[border:h - border, border:w - border]
    img2_y = img2_y[border:h - border, border:w - border]

    if img1_y.ndim == 2:
        return ssim(img1_y, img2_y)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')


def ssim(img1, img2):
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def make_optimizer(args, my_model):
    trainable = filter(lambda x: x.requires_grad, my_model.parameters())

    if args.optimizer == 'SGD':
        optimizer_function = optim.SGD
        kwargs = {'momentum': args.momentum}
    elif args.optimizer == 'ADAM':
        optimizer_function = optim.Adam
        kwargs = {
            'betas': (args.beta1, args.beta2),
            'eps': args.epsilon
        }
    elif args.optimizer == 'RMSprop':
        optimizer_function = optim.RMSprop
        kwargs = {'eps': args.epsilon}

    kwargs['weight_decay'] = args.weight_decay

    return optimizer_function(trainable, **kwargs)


def make_scheduler(args, my_optimizer):
    if args.decay_type == 'step':
        scheduler = lrs.StepLR(
            my_optimizer,
            step_size=args.lr_decay_sr,
            gamma=args.gamma_sr,
        )
    elif args.decay_type.find('step') >= 0:
        milestones = args.decay_type.split('_')
        milestones.pop(0)
        milestones = list(map(lambda x: int(x), milestones))
        scheduler = lrs.MultiStepLR(
            my_optimizer,
            milestones=milestones,
            gamma=args.gamma
        )

    scheduler.step(args.start_epoch - 1)

    return scheduler
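
from bisect import bisect_right

# Sketch (hypothetical helpers, not part of utility.py): how make_scheduler() reads a
# decay_type such as 'step_200_400' into milestones, and the closed-form learning
# rate MultiStepLR then applies: base_lr * gamma ** (number of milestones <= epoch).
# The decay_type string and values in the usage note are illustrative only.
def parse_milestones(decay_type):
    return [int(x) for x in decay_type.split('_')[1:]]

def multistep_lr(base_lr, gamma, milestones, epoch):
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)

# Usage: parse_milestones('step_200_400') returns [200, 400]; with base_lr=1e-4 and
# gamma=0.5 the rate is 1e-4 before epoch 200, 5e-5 from 200 and 2.5e-5 from 400.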



================================================
FILE: utils/__init__.py
================================================


================================================
FILE: utils/util.py
================================================
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def cal_sigma(sig_x, sig_y, radians):
    sig_x = sig_x.view(-1, 1, 1)
    sig_y = sig_y.view(-1, 1, 1)
    radians = radians.view(-1, 1, 1)

    D = torch.cat([F.pad(sig_x ** 2, [0, 1, 0, 0]), F.pad(sig_y ** 2, [1, 0, 0, 0])], 1)
    U = torch.cat([torch.cat([radians.cos(), -radians.sin()], 2),
                   torch.cat([radians.sin(), radians.cos()], 2)], 1)
    sigma = torch.bmm(U, torch.bmm(D, U.transpose(1, 2)))

    return sigma
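
import math

# Sketch (hypothetical scalar version, for illustration): cal_sigma() builds the 2x2
# covariance Sigma = U @ diag(sig_x**2, sig_y**2) @ U.T, where U is the rotation by
# `radians`. The squared inputs are therefore the eigenvalues of Sigma, and its trace
# and determinant do not change with the angle.
def covariance_2x2(sig_x, sig_y, radians):
    c, s = math.cos(radians), math.sin(radians)
    a, b = sig_x ** 2, sig_y ** 2
    # Entries of U @ diag(a, b) @ U.T written out by hand
    off_diag = (a - b) * c * s
    return [[a * c * c + b * s * s, off_diag],
            [off_diag, a * s * s + b * c * c]]

# Usage: covariance_2x2(2.0, 0.5, 0.0) returns [[4.0, 0.0], [0.0, 0.25]]; at any other
# angle the trace stays 4.25 and the determinant stays 1.0.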


def anisotropic_gaussian_kernel(batch, kernel_size, covar):
    ax = torch.arange(kernel_size).float().cuda() - kernel_size // 2

    xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    xy = torch.stack([xx, yy], -1).view(batch, -1, 2)

    inverse_sigma = torch.inverse(covar)
    kernel = torch.exp(- 0.5 * (torch.bmm(xy, inverse_sigma) * xy).sum(2)).view(batch, kernel_size, kernel_size)

    return kernel / kernel.sum([1, 2], keepdim=True)


def isotropic_gaussian_kernel(batch, kernel_size, sigma):
    ax = torch.arange(kernel_size).float().cuda() - kernel_size//2
    xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    kernel = torch.exp(-(xx ** 2 + yy ** 2) / (2. * sigma.view(-1, 1, 1) ** 2))

    return kernel / kernel.sum([1,2], keepdim=True)
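
import math

# Sketch (hypothetical pure-Python version of one kernel): isotropic_gaussian_kernel()
# evaluates exp(-(x**2 + y**2) / (2 * sigma**2)) on an integer grid centred at
# kernel_size // 2 and divides by the sum, so the kernel sums to 1 and a constant
# image stays constant after blurring. Rows index y and columns index x, as with
# yy and xx above.
def iso_gaussian_kernel_list(kernel_size, sigma):
    ax = [i - kernel_size // 2 for i in range(kernel_size)]
    k = [[math.exp(-(x ** 2 + y ** 2) / (2.0 * sigma ** 2)) for x in ax] for y in ax]
    total = sum(sum(row) for row in k)
    return [[v / total for v in row] for row in k]

# Usage: iso_gaussian_kernel_list(21, 2.6) gives a 21x21 kernel peaking at [10][10],
# matching the default kernel size and sig used in this file.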


def random_anisotropic_gaussian_kernel(batch=1, kernel_size=21, lambda_min=0.2, lambda_max=4.0):
    theta = torch.rand(batch).cuda() * math.pi
    lambda_1 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min
    lambda_2 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min

    covar = cal_sigma(lambda_1, lambda_2, theta)
    kernel = anisotropic_gaussian_kernel(batch, kernel_size, covar)
    return kernel


def stable_anisotropic_gaussian_kernel(kernel_size=21, theta=0, lambda_1=0.2, lambda_2=4.0):
    theta = torch.ones(1).cuda() * theta / 180 * math.pi
    lambda_1 = torch.ones(1).cuda() * lambda_1
    lambda_2 = torch.ones(1).cuda() * lambda_2

    covar = cal_sigma(lambda_1, lambda_2, theta)
    kernel = anisotropic_gaussian_kernel(1, kernel_size, covar)
    return kernel


def random_isotropic_gaussian_kernel(batch=1, kernel_size=21, sig_min=0.2, sig_max=4.0):
    x = torch.rand(batch).cuda() * (sig_max - sig_min) + sig_min
    k = isotropic_gaussian_kernel(batch, kernel_size, x)
    return k


def stable_isotropic_gaussian_kernel(kernel_size=21, sig=4.0):
    x = torch.ones(1).cuda() * sig
    k = isotropic_gaussian_kernel(1, kernel_size, x)
    return k


def random_gaussian_kernel(batch, kernel_size=21, blur_type='iso_gaussian', sig_min=0.2, sig_max=4.0, lambda_min=0.2, lambda_max=4.0):
    if blur_type == 'iso_gaussian':
        return random_isotropic_gaussian_kernel(batch=batch, kernel_size=kernel_size, sig_min=sig_min, sig_max=sig_max)
    elif blur_type == 'aniso_gaussian':
        return random_anisotropic_gaussian_kernel(batch=batch, kernel_size=kernel_size, lambda_min=lambda_min, lambda_max=lambda_max)


def stable_gaussian_kernel(kernel_size=21, blur_type='iso_gaussian', sig=2.6, lambda_1=0.2, lambda_2=4.0, theta=0):
    if blur_type == 'iso_gaussian':
        return stable_isotropic_gaussian_kernel(kernel_size=kernel_size, sig=sig)
    elif blur_type == 'aniso_gaussian':
        return stable_anisotropic_gaussian_kernel(kernel_size=kernel_size, lambda_1=lambda_1, lambda_2=lambda_2, theta=theta)


# implementation of matlab bicubic interpolation in pytorch
class bicubic(nn.Module):
    def __init__(self):
        super(bicubic, self).__init__()

    def cubic(self, x):
        absx = torch.abs(x)
        absx2 = torch.abs(x) * torch.abs(x)
        absx3 = torch.abs(x) * torch.abs(x) * torch.abs(x)

        condition1 = (absx <= 1).to(torch.float32)
        condition2 = ((1 < absx) & (absx <= 2)).to(torch.float32)

        f = (1.5 * absx3 - 2.5 * absx2 + 1) * condition1 + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * condition2
        return f

    def contribute(self, in_size, out_size, scale):
        kernel_width = 4
        if scale < 1:
            kernel_width = 4 / scale
        x0 = torch.arange(start=1, end=out_size[0] + 1).to(torch.float32).cuda()
        x1 = torch.arange(start=1, end=out_size[1] + 1).to(torch.float32).cuda()

        u0 = x0 / scale + 0.5 * (1 - 1 / scale)
        u1 = x1 / scale + 0.5 * (1 - 1 / scale)

        left0 = torch.floor(u0 - kernel_width / 2)
        left1 = torch.floor(u1 - kernel_width / 2)

        P = np.ceil(kernel_width) + 2

        indice0 = left0.unsqueeze(1) + torch.arange(start=0, end=P).to(torch.float32).unsqueeze(0).cuda()
        indice1 = left1.unsqueeze(1) + torch.arange(start=0, end=P).to(torch.float32).unsqueeze(0).cuda()

        mid0 = u0.unsqueeze(1) - indice0.unsqueeze(0)
        mid1 = u1.unsqueeze(1) - indice1.unsqueeze(0)

        if scale < 1:
            weight0 = scale * self.cubic(mid0 * scale)
            weight1 = scale * self.cubic(mid1 * scale)
        else:
            weight0 = self.cubic(mid0)
            weight1 = self.cubic(mid1)

        weight0 = weight0 / (torch.sum(weight0, 2).unsqueeze(2))
        weight1 = weight1 / (torch.sum(weight1, 2).unsqueeze(2))

        indice0 = torch.min(torch.max(torch.FloatTensor([1]).cuda(), indice0), torch.FloatTensor([in_size[0]]).cuda()).unsqueeze(0)
        indice1 = torch.min(torch.max(torch.FloatTensor([1]).cuda(), indice1), torch.FloatTensor([in_size[1]]).cuda()).unsqueeze(0)

        kill0 = torch.eq(weight0, 0)[0][0]
        kill1 = torch.eq(weight1, 0)[0][0]

        weight0 = weight0[:, :, kill0 == 0]
        weight1 = weight1[:, :, kill1 == 0]

        indice0 = indice0[:, :, kill0 == 0]
        indice1 = indice1[:, :, kill1 == 0]

        return weight0, weight1, indice0, indice1

    def forward(self, input, scale=1/4):
        b, c, h, w = input.shape

        weight0, weight1, indice0, indice1 = self.contribute([h, w], [int(h * scale), int(w * scale)], scale)
        weight0 = weight0[0]
        weight1 = weight1[0]

        indice0 = indice0[0].long()
        indice1 = indice1[0].long()

        out = input[:, :, (indice0 - 1), :] * (weight0.unsqueeze(0).unsqueeze(1).unsqueeze(4))
        out = (torch.sum(out, dim=3))
        A = out.permute(0, 1, 3, 2)

        out = A[:, :, (indice1 - 1), :] * (weight1.unsqueeze(0).unsqueeze(1).unsqueeze(4))
        out = out.sum(3).permute(0, 1, 3, 2)

        return out
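
import math

# Sketch (hypothetical scalar helpers mirroring bicubic.cubic and part of
# bicubic.contribute, for illustration): the Keys cubic kernel with a = -0.5, which
# is the kernel MATLAB's imresize uses, and the taps one output pixel reads along
# one axis. For downsampling (scale < 1) the kernel is stretched by 1 / scale, and
# the weights are renormalised to sum to 1. Unlike contribute(), this sketch neither
# clamps indices to [1, in_size] nor drops zero-weight taps.
def keys_cubic(x):
    ax = abs(x)
    if ax <= 1:
        return 1.5 * ax ** 3 - 2.5 * ax ** 2 + 1
    if ax <= 2:
        return -0.5 * ax ** 3 + 2.5 * ax ** 2 - 4 * ax + 2
    return 0.0

def output_weights(out_idx, scale):
    """Return (1-based input indices, normalised weights) for a 1-based output index."""
    kernel_width = 4 / scale if scale < 1 else 4
    u = out_idx / scale + 0.5 * (1 - 1 / scale)   # output centre in input coordinates
    left = math.floor(u - kernel_width / 2)
    taps = int(math.ceil(kernel_width)) + 2
    idx = [left + p for p in range(taps)]
    if scale < 1:
        w = [scale * keys_cubic((u - i) * scale) for i in idx]
    else:
        w = [keys_cubic(u - i) for i in idx]
    total = sum(w)
    return idx, [v / total for v in w]

# Usage: for x4 downsampling (scale=0.25), output pixel 1 is centred at input
# position 2.5, so inputs 2 and 3 receive equal, largest weights.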


class Gaussin_Kernel(object):
    def __init__(self, kernel_size=21, blur_type='iso_gaussian',
                 sig=2.6, sig_min=0.2, sig_max=4.0,
                 lambda_1=0.2, lambda_2=4.0, theta=0, lambda_min=0.2, lambda_max=4.0):
        self.kernel_size = kernel_size
        self.blur_type = blur_type

        self.sig = sig
        self.sig_min = sig_min
        self.sig_max = sig_max

        self.lambda_1 = lambda_1
        self.lambda_2 = lambda_2
        self.theta = theta
        self.lambda_min = lambda_min
        self.lambda_max = lambda_max

    def __call__(self, batch, random):
        # random kernel
        if random:
            return random_gaussian_kernel(batch, kernel_size=self.kernel_size, blur_type=self.blur_type,
                                          sig_min=self.sig_min, sig_max=self.sig_max,
                                          lambda_min=self.lambda_min, lambda_max=self.lambda_max)

        # stable kernel
        else:
            return stable_gaussian_kernel(kernel_size=self.kernel_size, blur_type=self.blur_type,
                                          sig=self.sig,
                                          lambda_1=self.lambda_1, lambda_2=self.lambda_2, theta=self.theta)

class BatchBlur(nn.Module):
    def __init__(self, kernel_size=21):
        super(BatchBlur, self).__init__()
        self.kernel_size = kernel_size
        if kernel_size % 2 == 1:
            self.pad = nn.ReflectionPad2d(kernel_size//2)
        else:
            self.pad = nn.ReflectionPad2d((kernel_size//2, kernel_size//2-1, kernel_size//2, kernel_size//2-1))

    def forward(self, input, kernel):
        B, C, H, W = input.size()
        input_pad = self.pad(input)
        H_p, W_p = input_pad.size()[-2:]

        if len(kernel.size()) == 2:
            input_CBHW = input_pad.view((C * B, 1, H_p, W_p))
            kernel = kernel.contiguous().view((1, 1, self.kernel_size, self.kernel_size))

            return F.conv2d(input_CBHW, kernel, padding=0).view((B, C, H, W))
        else:
            input_CBHW = input_pad.view((1, C * B, H_p, W_p))
            kernel = kernel.contiguous().view((B, 1, self.kernel_size, self.kernel_size))
            kernel = kernel.repeat(1, C, 1, 1).view((B * C, 1, self.kernel_size, self.kernel_size))

            return F.conv2d(input_CBHW, kernel, groups=B*C).view((B, C, H, W))
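
# Sketch (hypothetical 1-D analogue of BatchBlur, for illustration): reflection-pad by
# kernel_size // 2 on the left and by kernel_size // 2 (odd sizes) or
# kernel_size // 2 - 1 (even sizes) on the right, then take a 'valid' correlation,
# so the output has the same length as the input. Like F.conv2d, this computes a
# cross-correlation, i.e. the kernel is not flipped.
def reflect_pad_1d(x, left, right):
    # Mirrors nn.ReflectionPad2d: the edge sample itself is not repeated.
    return [x[i] for i in range(left, 0, -1)] + list(x) + [x[-2 - i] for i in range(right)]

def blur_1d(x, kernel):
    k = len(kernel)
    left = k // 2
    right = k // 2 if k % 2 == 1 else k // 2 - 1
    xp = reflect_pad_1d(x, left, right)
    return [sum(kernel[j] * xp[i + j] for j in range(k)) for i in range(len(x))]

# Usage: reflect_pad_1d([1, 2, 3, 4], 2, 2) returns [3, 2, 1, 2, 3, 4, 3, 2], and
# blur_1d([5] * 6, [0.25, 0.5, 0.25]) returns [5.0] * 6.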


class SRMDPreprocessing(object):
    def __init__(self,
                 scale,
                 mode='bicubic',
                 kernel_size=21,
                 blur_type='iso_gaussian',
                 sig=2.6,
                 sig_min=0.2,
                 sig_max=4.0,
                 lambda_1=0.2,
                 lambda_2=4.0,
                 theta=0,
                 lambda_min=0.2,
                 lambda_max=4.0,
                 noise=0.0
                 ):
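        '''
        # sig, sig_min and sig_max are used for isotropic Gaussian blurs
        During training phase (random=True):
            the width of the blur kernel is randomly sampled from [sig_min, sig_max]
        During test phase (random=False):
            the width of the blur kernel is set to sig
        If blur_type is iso_gaussian and sig == 0, no blur is applied (downsampling only).

        # lambda_1, lambda_2, theta, lambda_min and lambda_max are used for anisotropic Gaussian blurs
        During training phase (random=True):
            lambda_1 and lambda_2 are randomly sampled from [lambda_min, lambda_max]
            the angle is randomly sampled from [0, pi] (radians)
        During test phase (random=False):
            lambda_1 and lambda_2 are set to the given values
            the angle is set to theta (degrees)
        Note: cal_sigma squares lambda_1 and lambda_2, so the eigenvalues of the
        covariance are lambda_1**2 and lambda_2**2.

        # noise is the standard deviation of additive Gaussian noise on the LR image
        During training phase (random=True):
            the noise level of each image is randomly sampled from [0, noise]
        During test phase (random=False):
            the noise level is set to noise
        '''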
        self.kernel_size = kernel_size
        self.scale = scale
        self.mode = mode
        self.noise = noise

        self.gen_kernel = Gaussin_Kernel(
            kernel_size=kernel_size, blur_type=blur_type,
            sig=sig, sig_min=sig_min, sig_max=sig_max,
            lambda_1=lambda_1, lambda_2=lambda_2, theta=theta, lambda_min=lambda_min, lambda_max=lambda_max
        )
        self.blur = BatchBlur(kernel_size=kernel_size)
        self.bicubic = bicubic()

    def __call__(self, hr_tensor, random=True):
        with torch.no_grad():
            # only downsampling
            if self.gen_kernel.blur_type == 'iso_gaussian' and self.gen_kernel.sig == 0:
                B, N, C, H, W = hr_tensor.size()
                hr_blured = hr_tensor.view(-1, C, H, W)
                b_kernels = None

            # gaussian blur + downsampling
            else:
                B, N, C, H, W = hr_tensor.size()
                b_kernels = self.gen_kernel(B, random)  # B degradations

                # blur
                hr_blured = self.blur(hr_tensor.view(B, -1, H, W), b_kernels)
                hr_blured = hr_blured.view(-1, C, H, W)  # BN, C, H, W

            # downsampling
            if self.mode == 'bicubic':
                lr_blured = self.bicubic(hr_blured, scale=1/self.scale)
            elif self.mode == 's-fold':
                s = int(self.scale)  # view() needs integer sizes; scale may be a float
                lr_blured = hr_blured.view(-1, C, H // s, s, W // s, s)[:, :, :, 0, :, 0]


            # add noise
            if self.noise > 0:
                _, C, H_lr, W_lr = lr_blured.size()
                noise_level = torch.rand(B, 1, 1, 1, 1).to(lr_blured.device) * self.noise if random else self.noise
                noise = torch.randn_like(lr_blured).view(-1, N, C, H_lr, W_lr).mul_(noise_level).view(-1, C, H_lr, W_lr)
                lr_blured.add_(noise)

            lr_blured = torch.clamp(lr_blured.round(), 0, 255)


            return lr_blured.view(B, N, C, H//int(self.scale), W//int(self.scale)), b_kernels
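
For illustration, the blur / s-fold / noise path of `SRMDPreprocessing.__call__` can be sketched as standalone PyTorch functions. This is a minimal sketch under stated assumptions, not the repository's implementation: `batch_blur`, `s_fold_downsample` and `degrade` are hypothetical names, the per-image blur is written as a grouped `F.conv2d` with reflect padding (the repository's `BatchBlur` may pad differently), and the bicubic branch is left out because it relies on the repository's own `bicubic` module.

```python
import torch
import torch.nn.functional as F


def batch_blur(x, kernels):
    """Blur sample b of x (B, C, H, W) with kernels[b]; kernels is (B, k, k), k odd.

    Assumption: reflect padding keeps H and W unchanged; the repository's
    BatchBlur may pad differently.
    """
    B, C, H, W = x.shape
    pad = kernels.shape[-1] // 2
    x = F.pad(x, (pad, pad, pad, pad), mode='reflect')
    # fold the batch into the channel axis so one grouped convolution
    # applies kernels[b] to every channel of sample b
    x = x.reshape(1, B * C, H + 2 * pad, W + 2 * pad)
    weight = kernels.repeat_interleave(C, dim=0).unsqueeze(1)  # (B*C, 1, k, k)
    return F.conv2d(x, weight, groups=B * C).reshape(B, C, H, W)


def s_fold_downsample(x, scale):
    # keep the top-left pixel of every scale x scale block (same as x[:, :, ::s, ::s])
    s = int(scale)
    n, c, h, w = x.shape
    return x.reshape(n, c, h // s, s, w // s, s)[:, :, :, 0, :, 0]


def degrade(hr, kernels, scale, noise=0.0, random=True):
    """hr: (B, N, C, H, W) float in [0, 255]; kernels: (B, k, k), one per image."""
    B, N, C, H, W = hr.shape
    s = int(scale)
    # the N patches of one image share its kernel, so stack them as channels
    blurred = batch_blur(hr.reshape(B, N * C, H, W), kernels).reshape(B * N, C, H, W)
    lr = s_fold_downsample(blurred, s)
    if noise > 0:
        # one noise level per image: uniform in [0, noise) when random, else fixed
        if random:
            level = torch.rand(B, 1, 1, 1, 1, device=lr.device) * noise
        else:
            level = noise
        gauss = torch.randn(B, N, C, H // s, W // s, device=lr.device) * level
        lr = lr + gauss.reshape(B * N, C, H // s, W // s)
    # quantize to integer intensities, as the original code does
    lr = torch.clamp(lr.round(), 0, 255)
    return lr.reshape(B, N, C, H // s, W // s)
```

Folding the batch into the channel axis lets a single grouped convolution apply a different kernel to each image, and the N patches of an image share that kernel, mirroring the `hr_tensor.view(B, -1, H, W)` reshape in the original code.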

SYMBOL INDEX (154 symbols across 19 files)

FILE: data/__init__.py
  class Data (line 4) | class Data:
    method __init__ (line 5) | def __init__(self, args):

FILE: data/benchmark.py
  class Benchmark (line 6) | class Benchmark(srdata.SRData):
    method __init__ (line 7) | def __init__(self, args, name='', train=True):
    method _set_filesystem (line 12) | def _set_filesystem(self, dir_data):

FILE: data/common.py
  function get_patch (line 7) | def get_patch(img, patch_size=48, scale=1):
  function set_channel (line 18) | def set_channel(img, n_channels=3):
  function np2Tensor (line 31) | def np2Tensor(img, rgb_range=255):
  function augment (line 39) | def augment(img, hflip=True, rot=True):

FILE: data/df2k.py
  class DF2K (line 5) | class DF2K(multiscalesrdata.SRData):
    method __init__ (line 6) | def __init__(self, args, name='DF2K', train=True, benchmark=False):
    method _scan (line 9) | def _scan(self):
    method _set_filesystem (line 15) | def _set_filesystem(self, dir_data):

FILE: data/multiscalesrdata.py
  class SRData (line 13) | class SRData(data.Dataset):
    method __init__ (line 14) | def __init__(self, args, name='', train=True, benchmark=False):
    method _scan (line 68) | def _scan(self):
    method _set_filesystem (line 76) | def _set_filesystem(self, dir_data):
    method _name_hrbin (line 82) | def _name_hrbin(self):
    method _name_lrbin (line 89) | def _name_lrbin(self, scale):
    method _check_and_load (line 96) | def _check_and_load(self, ext, l, f, verbose=True, load=True):
    method __getitem__ (line 119) | def __getitem__(self, idx):
    method __len__ (line 128) | def __len__(self):
    method _get_index (line 134) | def _get_index(self, idx):
    method _load_file (line 140) | def _load_file(self, idx):
    method get_patch (line 157) | def get_patch(self, hr):
    method set_scale (line 174) | def set_scale(self, idx_scale):

FILE: dataloader.py
  function _ms_loop (line 20) | def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, ...
  class _MSDataLoaderIter (line 46) | class _MSDataLoaderIter(_DataLoaderIter):
    method __init__ (line 47) | def __init__(self, loader):
  class MSDataLoader (line 117) | class MSDataLoader(DataLoader):
    method __init__ (line 118) | def __init__(
    method __iter__ (line 133) | def __iter__(self):

FILE: loss/__init__.py
  class Loss (line 13) | class Loss(nn.modules.loss._Loss):
    method __init__ (line 14) | def __init__(self, args, ckp):
    method forward (line 70) | def forward(self, sr, hr):
    method step (line 87) | def step(self):
    method start_log (line 92) | def start_log(self):
    method end_log (line 95) | def end_log(self, n_batches):
    method display_loss (line 98) | def display_loss(self, batch):
    method plot_loss (line 106) | def plot_loss(self, apath, epoch):
    method get_loss_module (line 120) | def get_loss_module(self):
    method save (line 126) | def save(self, apath):
    method load (line 130) | def load(self, apath, cpu=False):

FILE: loss/adversarial.py
  class Adversarial (line 11) | class Adversarial(nn.Module):
    method __init__ (line 12) | def __init__(self, args, gan_type):
    method forward (line 26) | def forward(self, fake, real):
    method state_dict (line 78) | def state_dict(self, *args, **kwargs):

FILE: loss/discriminator.py
  class Discriminator (line 5) | class Discriminator(nn.Module):
    method __init__ (line 6) | def __init__(self, args, gan_type='GAN'):
    method forward (line 40) | def forward(self, x):

FILE: loss/vgg.py
  class VGG (line 9) | class VGG(nn.Module):
    method __init__ (line 10) | def __init__(self, conv_index, rgb_range=1):
    method forward (line 24) | def forward(self, sr, hr):

FILE: moco/builder.py
  class MoCo (line 6) | class MoCo(nn.Module):
    method __init__ (line 11) | def __init__(self, base_encoder, dim=256, K=32*256, m=0.999, T=0.07, m...
    method _momentum_update_key_encoder (line 40) | def _momentum_update_key_encoder(self):
    method _dequeue_and_enqueue (line 48) | def _dequeue_and_enqueue(self, keys):
    method _batch_shuffle_ddp (line 63) | def _batch_shuffle_ddp(self, x):
    method _batch_unshuffle_ddp (line 91) | def _batch_unshuffle_ddp(self, x, idx_unshuffle):
    method forward (line 109) | def forward(self, im_q, im_k):
  function concat_all_gather (line 157) | def concat_all_gather(tensor):

FILE: model/__init__.py
  class Model (line 8) | class Model(nn.Module):
    method __init__ (line 9) | def __init__(self, args, ckp):
    method forward (line 38) | def forward(self, x):
    method get_model (line 51) | def get_model(self):
    method state_dict (line 57) | def state_dict(self, **kwargs):
    method save (line 61) | def save(self, apath, epoch, is_best=False):
    method load (line 79) | def load(self, apath, pre_train='.', resume=-1, cpu=False):
    method forward_chop (line 104) | def forward_chop(self, x, shave=10, min_size=160000):
    method forward_x8 (line 145) | def forward_x8(self, x, forward_function):

FILE: model/blindsr.py
  function make_model (line 8) | def make_model(args):
  class DA_conv (line 12) | class DA_conv(nn.Module):
    method __init__ (line 13) | def __init__(self, channels_in, channels_out, kernel_size, reduction):
    method forward (line 29) | def forward(self, x):
  class CA_layer (line 47) | class CA_layer(nn.Module):
    method __init__ (line 48) | def __init__(self, channels_in, channels_out, reduction):
    method forward (line 57) | def forward(self, x):
  class DAB (line 67) | class DAB(nn.Module):
    method __init__ (line 68) | def __init__(self, conv, n_feat, kernel_size, reduction):
    method forward (line 78) | def forward(self, x):
  class DAG (line 92) | class DAG(nn.Module):
    method __init__ (line 93) | def __init__(self, conv, n_feat, kernel_size, reduction, n_blocks):
    method forward (line 104) | def forward(self, x):
  class DASR (line 118) | class DASR(nn.Module):
    method __init__ (line 119) | def __init__(self, args, conv=common.default_conv):
    method forward (line 158) | def forward(self, x, k_v):
  class Encoder (line 183) | class Encoder(nn.Module):
    method __init__ (line 184) | def __init__(self):
    method forward (line 214) | def forward(self, x):
  class BlindSR (line 221) | class BlindSR(nn.Module):
    method __init__ (line 222) | def __init__(self, args):
    method forward (line 231) | def forward(self, x):

FILE: model/common.py
  function default_conv (line 7) | def default_conv(in_channels, out_channels, kernel_size, bias=True):
  class MeanShift (line 11) | class MeanShift(nn.Conv2d):
    method __init__ (line 12) | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
  class Upsampler (line 23) | class Upsampler(nn.Sequential):
    method __init__ (line 24) | def __init__(self, conv, scale, n_feat, act=False, bias=True):

FILE: quick_test.py
  function parse_args (line 11) | def parse_args():
  function main (line 24) | def main():

FILE: template.py
  function set_template (line 1) | def set_template(args):

FILE: trainer.py
  class Trainer (line 9) | class Trainer():
    method __init__ (line 10) | def __init__(self, args, loader, my_model, my_loss, ckp):
    method train (line 30) | def train(self):
    method test (line 128) | def test(self):
    method crop_border (line 192) | def crop_border(self, img_hr, scale):
    method terminate (line 199) | def terminate(self):

FILE: utility.py
  class AverageMeter (line 14) | class AverageMeter(object):
    method __init__ (line 16) | def __init__(self):
    method reset (line 19) | def reset(self):
    method update (line 25) | def update(self, val, n=1):
  class timer (line 32) | class timer():
    method __init__ (line 33) | def __init__(self):
    method tic (line 37) | def tic(self):
    method toc (line 40) | def toc(self):
    method hold (line 43) | def hold(self):
    method release (line 46) | def release(self):
    method reset (line 52) | def reset(self):
  class checkpoint (line 56) | class checkpoint():
    method __init__ (line 57) | def __init__(self, args):
    method save (line 83) | def save(self, trainer, epoch, is_best=False):
    method add_log (line 95) | def add_log(self, log):
    method write_log (line 98) | def write_log(self, log, refresh=False):
    method done (line 105) | def done(self):
    method plot_psnr (line 108) | def plot_psnr(self, epoch):
    method save_results (line 126) | def save_results(self, filename, save_list, scale):
  function quantize (line 134) | def quantize(img, rgb_range):
  function calc_psnr (line 139) | def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
  function calc_ssim (line 160) | def calc_ssim(img1, img2, scale=2, benchmark=False):
  function ssim (line 197) | def ssim(img1, img2):
  function make_optimizer (line 220) | def make_optimizer(args, my_model):
  function make_scheduler (line 241) | def make_scheduler(args, my_optimizer):

FILE: utils/util.py
  function cal_sigma (line 8) | def cal_sigma(sig_x, sig_y, radians):
  function anisotropic_gaussian_kernel (line 21) | def anisotropic_gaussian_kernel(batch, kernel_size, covar):
  function isotropic_gaussian_kernel (line 34) | def isotropic_gaussian_kernel(batch, kernel_size, sigma):
  function random_anisotropic_gaussian_kernel (line 43) | def random_anisotropic_gaussian_kernel(batch=1, kernel_size=21, lambda_m...
  function stable_anisotropic_gaussian_kernel (line 53) | def stable_anisotropic_gaussian_kernel(kernel_size=21, theta=0, lambda_1...
  function random_isotropic_gaussian_kernel (line 63) | def random_isotropic_gaussian_kernel(batch=1, kernel_size=21, sig_min=0....
  function stable_isotropic_gaussian_kernel (line 69) | def stable_isotropic_gaussian_kernel(kernel_size=21, sig=4.0):
  function random_gaussian_kernel (line 75) | def random_gaussian_kernel(batch, kernel_size=21, blur_type='iso_gaussia...
  function stable_gaussian_kernel (line 82) | def stable_gaussian_kernel(kernel_size=21, blur_type='iso_gaussian', sig...
  class bicubic (line 90) | class bicubic(nn.Module):
    method __init__ (line 91) | def __init__(self):
    method cubic (line 94) | def cubic(self, x):
    method contribute (line 105) | def contribute(self, in_size, out_size, scale):
    method forward (line 150) | def forward(self, input, scale=1/4):
  class Gaussin_Kernel (line 170) | class Gaussin_Kernel(object):
    method __init__ (line 171) | def __init__(self, kernel_size=21, blur_type='iso_gaussian',
    method __call__ (line 187) | def __call__(self, batch, random):
  class BatchBlur (line 200) | class BatchBlur(nn.Module):
    method __init__ (line 201) | def __init__(self, kernel_size=21):
    method forward (line 209) | def forward(self, input, kernel):
  class SRMDPreprocessing (line 227) | class SRMDPreprocessing(object):
    method __init__ (line 228) | def __init__(self,
    method __call__ (line 271) | def __call__(self, hr_tensor, random=True):
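
The `SRMDPreprocessing` docstring describes an anisotropic Gaussian blur by two covariance eigenvalues (`lambda_1`, `lambda_2`) and a rotation angle `theta`, drawn from [lambda_min, lambda_max] and [0, pi] during training and fixed at test time. A minimal sketch of that construction, Sigma = R(theta) diag(lambda_1, lambda_2) R(theta)^T, is given below. It takes the eigenvalues literally as the docstring states; the repository's `cal_sigma` and `anisotropic_gaussian_kernel` may parameterize them differently (for example as standard deviations), and `aniso_gaussian_kernels` and `sample_kernels` are hypothetical names.

```python
import math
import torch


def aniso_gaussian_kernels(kernel_size, lambda_1, lambda_2, theta):
    """Return (B, k, k) normalized kernels from 1-D tensors of length B.

    Assumption (from the docstring): lambda_1 and lambda_2 are the eigenvalues
    of the covariance and theta rotates its eigenvectors.
    """
    cos, sin = torch.cos(theta), torch.sin(theta)
    # rotation matrix R = [[cos, -sin], [sin, cos]] per batch element
    R = torch.stack([torch.stack([cos, -sin], -1), torch.stack([sin, cos], -1)], -2)
    D = torch.diag_embed(torch.stack([lambda_1, lambda_2], -1))  # (B, 2, 2)
    sigma = R @ D @ R.transpose(-1, -2)
    inv = torch.linalg.inv(sigma)
    # centered (x, y) grid; x runs along columns, y along rows
    r = torch.arange(kernel_size, dtype=theta.dtype) - (kernel_size - 1) / 2
    yy, xx = torch.meshgrid(r, r, indexing='ij')
    pos = torch.stack([xx, yy], -1).reshape(-1, 2)  # (k*k, 2)
    # quadratic form p^T Sigma^{-1} p for every grid point and batch element
    quad = torch.einsum('pi,bij,pj->bp', pos, inv, pos)
    k = torch.exp(-0.5 * quad).reshape(-1, kernel_size, kernel_size)
    return k / k.sum(dim=(-2, -1), keepdim=True)


def sample_kernels(batch, kernel_size, lambda_min, lambda_max):
    # training phase: each eigenvalue uniform in [lambda_min, lambda_max],
    # angle uniform in [0, pi) (independent sampling is an assumption)
    lam = torch.rand(2, batch) * (lambda_max - lambda_min) + lambda_min
    theta = torch.rand(batch) * math.pi
    return aniso_gaussian_kernels(kernel_size, lam[0], lam[1], theta)
```

At test time one would call `aniso_gaussian_kernels` directly with fixed `lambda_1`, `lambda_2` and `theta` instead of `sample_kernels`. Rotating by pi/2 swaps the two eigen-directions, so the resulting kernel is the transpose of the theta = 0 kernel.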
