Repository: wwlCape/HAN
Branch: master
Commit: 0595e23a9925
Files: 47
Total size: 341.4 KB

Directory structure:
gitextract_zj0t_fxr/

├── .gitignore
├── LICENSE
├── README.md
├── experiment/
│   └── .gitignore
└── src/
    ├── __init__.py
    ├── data/
    │   ├── __init__.py
    │   ├── benchmark.py
    │   ├── common.py
    │   ├── demo.py
    │   ├── div2k.py
    │   ├── div2kjpeg.py
    │   ├── sr291.py
    │   ├── srdata.py
    │   └── video.py
    ├── dataloader.py
    ├── demo.sh
    ├── loss/
    │   ├── __init__.py
    │   ├── adversarial.py
    │   ├── discriminator.py
    │   └── vgg.py
    ├── main.py
    ├── model/
    │   ├── __init__.py
    │   ├── common.py
    │   ├── dcn/
    │   │   ├── __init__.py
    │   │   ├── deform_conv.py
    │   │   ├── setup.py
    │   │   └── src/
    │   │       ├── deform_conv_cuda.cpp
    │   │       └── deform_conv_cuda_kernel.cu
    │   ├── ddbpn.py
    │   ├── edsr.py
    │   ├── han.py
    │   ├── matrixmodel.py
    │   ├── mdsr.py
    │   ├── ops.py
    │   ├── rcan.py
    │   ├── rcan1.py
    │   ├── rcan3.py
    │   ├── rcan4.py
    │   ├── rdn.py
    │   ├── rdn1.py
    │   ├── rdn2.py
    │   └── vdsr.py
    ├── option.py
    ├── template.py
    ├── trainer.py
    ├── utility.py
    └── videotester.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
## HAN

> PyTorch code for our ECCV 2020 paper "Single Image Super-Resolution via a Holistic Attention Network"
>
> This repository is for HAN, introduced in the following paper:
>
> Ben Niu, Weilei Wen, Wenqi Ren, Xiangde Zhang, Lianping Yang, Shuzhen Wang, Kaihao Zhang, Xiaochun Cao, Haifeng Shen, "Single Image Super-Resolution via a Holistic Attention Network", ECCV 2020, [arxiv](https://arxiv.org/abs/2008.08767)
>
> The code is built on RCAN (PyTorch) and tested on Ubuntu 16.04/18.04 (Python 3.6, PyTorch 0.4.0, CUDA 8.0, cuDNN 5.1) with Titan X/1080Ti/Xp GPUs.
>
> ### Contents
>
> ________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________
>
> > 1. [Introduction](https://github.com/wwlCape/HAN#introduction)
> > 2. [Train](https://github.com/wwlCape/HAN#begin-to-train)
> > 3. [Test](https://github.com/wwlCape/HAN#begin-to-test)
> > 4. [Acknowledgements](https://github.com/wwlCape/HAN#Acknowledgements)
>
> ### Introduction
>
> Informative features play a crucial role in the single image super-resolution task. Channel attention has been demonstrated to be effective for preserving information-rich features in each layer. However, channel attention treats each convolution layer as a separate process and thus misses the correlation among different layers. To address this problem, we propose a new holistic attention network (HAN), which consists of a layer attention module (LAM) and a channel-spatial attention module (CSAM), to model the holistic interdependencies among layers, channels, and positions. Specifically, the proposed LAM adaptively emphasizes hierarchical features by considering correlations among layers. Meanwhile, CSAM learns the confidence at all the positions of each channel to selectively capture more informative features. Extensive experiments demonstrate that the proposed HAN performs favorably against the state-of-the-art single image super-resolution approaches.
>
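The layer attention idea described above can be illustrated with a minimal NumPy sketch. It assumes the N layer (or residual-group) outputs for one image are stacked into an array of shape (N, C, H, W). The function name `layer_attention`, the fixed scalar `alpha` (a learnable parameter in a real network), and the plain row-wise softmax are illustrative assumptions, not a copy of `src/model/han.py`.

```python
import numpy as np


def layer_attention(features, alpha=1.0):
    """Sketch of a layer attention module (LAM).

    features: array of shape (N, C, H, W), the outputs of N layers for one image.
    alpha: residual scale; learnable in a real network, fixed here for illustration.
    """
    n = features.shape[0]
    flat = features.reshape(n, -1)                     # (N, C*H*W): one row per layer
    corr = flat @ flat.T                               # (N, N) layer-to-layer correlations
    corr = corr - corr.max(axis=1, keepdims=True)      # stabilize the softmax
    weights = np.exp(corr)
    weights /= weights.sum(axis=1, keepdims=True)      # row-wise softmax
    attended = weights @ flat                          # each layer becomes a weighted sum of all layers
    out = alpha * attended + flat                      # residual connection keeps the original features
    return out.reshape(features.shape)


rng = np.random.default_rng(0)
feats = rng.standard_normal((4, 8, 5, 5)).astype(np.float32)
out = layer_attention(feats, alpha=0.5)
print(out.shape)  # (4, 8, 5, 5)
```

With `alpha = 0` the module reduces to the identity, which is one reason a residual scale of this kind is often initialized to zero.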
>
> ### Train
>
> #### Prepare training data
>
> Download the DIV2K training data (800 training + 100 validation images) from the DIV2K dataset.
>
> ### Begin to train
>
> (Optional) Download the models for our paper and place them in '/HAN/experiment/HAN'. All the models (BIX2/3/4/8, BDX3) can be downloaded from [GoogleDrive](https://drive.google.com/drive/folders/17cLcPCDLuBV5_5-ngd0vXIDp6rebIMG1). You can use the scripts in 'demo.sh' to train the models for our paper. The commands below initialize training from the pre-trained RCAN x2 model, expected at '/HAN/experiment/model/RCAN_BIX2.pt'.
>
> ```bash
> # BI, scale 2, 3, 4, 8
> #HAN BI model (x2)
> 
> python main.py --template HAN --save HANx2 --scale 2 --reset --save_results --patch_size 96 --pre_train ../experiment/model/RCAN_BIX2.pt
> 
> #HAN BI model (x3)
> 
> python main.py --template HAN --save HANx3 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt
> 
> #HAN BI model (x4)
> 
> python main.py --template HAN --save HANx4 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt
> 
> #HAN BI model (x8)
> 
> python main.py --template HAN --save HANx8 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX2.pt
> 
> 
> ```
>
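The four training commands differ only in `--scale` and `--patch_size`, and `--patch_size` is always 48 * scale. Because `get_patch` in 'src/data/common.py' crops `patch_size // scale` pixels from the LR image, every scale appears to train on 48 x 48 LR patches. The hypothetical helper `train_command` below regenerates the commands under that reading; it is a convenience sketch, not part of the repository.

```python
def train_command(scale, lr_patch=48):
    """Rebuild the README training command for one scale (hypothetical helper)."""
    # --patch_size is the HR patch size; get_patch crops patch_size // scale from the LR image
    patch_size = lr_patch * scale
    return ('python main.py --template HAN --save HANx{s} --scale {s} --reset '
            '--save_results --patch_size {p} --pre_train ../experiment/model/RCAN_BIX2.pt'
            ).format(s=scale, p=patch_size)


for s in (2, 3, 4, 8):
    print(train_command(s))
```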
> ### Begin to Test
>
> #### Quick start
>
> Download the models for our paper and place them in '/HAN/experiment/HAN'. Then `cd` to '/HAN/src' and run the following script.
>
> ```bash
> # test
> python main.py --template HAN --data_test Set5+Set14+B100+Urban100+Manga109 --data_range 801-900 --scale 2 --pre_train ../experiment/HAN/HAN_BIX2.pt --test_only --save HANx2_test --save_results
> ```
>
> All the models (BIX2/3/4/8, BDX3) can be downloaded from [GoogleDrive](https://drive.google.com/drive/folders/17cLcPCDLuBV5_5-ngd0vXIDp6rebIMG1).
>
> The whole test pipeline:
>
> 1. Prepare test data. Place the original test sets in '/dataset/x4/test', then run 'Prepare_TestData_HR_LR.m' in Matlab to generate HR/LR images with different degradation models.
> 2. Conduct image SR. See Quick start.
> 3. Evaluate the results. Run 'Evaluate_PSNR_SSIM.m' to obtain the PSNR/SSIM values reported in the paper.
>
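For readers without MATLAB, the PSNR part of the evaluation step can be approximated in Python. This sketch assumes the common SR protocol: PSNR on the Y channel of ITU-R BT.601 YCbCr (the formula MATLAB's `rgb2ycbcr` uses for 8-bit input), with `scale` pixels shaved from each border and a peak value of 255. Whether 'Evaluate_PSNR_SSIM.m' does exactly this is an assumption, so treat the numbers as approximate; `rgb_to_y`, `shave`, and `psnr_y` are hypothetical helpers.

```python
import numpy as np


def rgb_to_y(img):
    """BT.601 luma for an RGB image with values in [0, 255]."""
    img = np.asarray(img, dtype=np.float64)
    return 16.0 + (65.481 * img[..., 0] + 128.553 * img[..., 1]
                   + 24.966 * img[..., 2]) / 255.0


def shave(y, border):
    """Drop `border` pixels on every side (no-op when border is 0)."""
    return y[border:-border, border:-border] if border > 0 else y


def psnr_y(sr, hr, scale):
    """PSNR in dB between two RGB images, measured on the shaved Y channel."""
    y_sr = shave(rgb_to_y(sr), scale)
    y_hr = shave(rgb_to_y(hr), scale)
    mse = np.mean((y_sr - y_hr) ** 2)
    if mse == 0:
        return float('inf')
    return float(10.0 * np.log10(255.0 ** 2 / mse))


hr = np.zeros((8, 8, 3))
sr = np.full((8, 8, 3), 255.0)
print(round(psnr_y(sr, hr, scale=2), 3))
```

Black (Y = 16) against white (Y = 235) differs by 219 on every pixel, so the result is 20 * log10(255 / 219), roughly 1.32 dB.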
> ### Acknowledgements
>
> This code is built on [RCAN](https://github.com/yulunzhang/RCAN). We thank the authors for sharing the code of RCAN ([PyTorch version](https://github.com/yulunzhang/RCAN)).



================================================
FILE: experiment/.gitignore
================================================
*
!.gitignore
!/model/*.pt


================================================
FILE: src/__init__.py
================================================


================================================
FILE: src/data/__init__.py
================================================
from importlib import import_module
#from dataloader import MSDataLoader
from torch.utils.data import dataloader
from torch.utils.data import ConcatDataset

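# A thin wrapper around ConcatDataset that exposes the `train` flag of its
# first dataset and forwards set_scale() to every sub-dataset that supports it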
class MyConcatDataset(ConcatDataset):
    def __init__(self, datasets):
        super(MyConcatDataset, self).__init__(datasets)
        self.train = datasets[0].train

    def set_scale(self, idx_scale):
        for d in self.datasets:
            if hasattr(d, 'set_scale'): d.set_scale(idx_scale)

class Data:
    def __init__(self, args):
        self.loader_train = None
        if not args.test_only:
            datasets = []
            for d in args.data_train:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data.' + module_name.lower())
                datasets.append(getattr(m, module_name)(args, name=d))

            self.loader_train = dataloader.DataLoader(
                MyConcatDataset(datasets),
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=not args.cpu,
                num_workers=args.n_threads,
            )

        self.loader_test = []
        for d in args.data_test:
            if d in ['Val20', 'Set20', 'Set5', 'Set14', 'B100', 'Urban100','Manga109']:
                m = import_module('data.benchmark')
                testset = getattr(m, 'Benchmark')(args, train=False, name=d)
            else:
                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
                m = import_module('data.' + module_name.lower())
                testset = getattr(m, module_name)(args, train=False, name=d)

            self.loader_test.append(
                dataloader.DataLoader(
                    testset,
                    batch_size=1,
                    shuffle=False,
                    pin_memory=not args.cpu,
                    num_workers=args.n_threads,
                )
            )
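Dataset names passed on the command line are turned into modules by name: benchmark sets go to `data.benchmark`, any `DIV2K-Q<n>` name goes to `data.div2kjpeg`, and everything else to `data.<name lowercased>`. The sketch below mirrors that mapping without PyTorch. `resolve_dataset` and `jpeg_quality` are hypothetical helpers, and the `'+'` split assumes option.py turns a string such as `Set5+Set14` into a list, as the README's test command suggests.

```python
BENCHMARKS = {'Val20', 'Set20', 'Set5', 'Set14', 'B100', 'Urban100', 'Manga109'}


def resolve_dataset(name, train):
    """Return (module path, class name) the way Data.__init__ picks them."""
    if not train and name in BENCHMARKS:
        return 'data.benchmark', 'Benchmark'
    # Any DIV2K-Q<quality> name is served by the DIV2KJPEG class
    module_name = name if name.find('DIV2K-Q') < 0 else 'DIV2KJPEG'
    return 'data.' + module_name.lower(), module_name


def jpeg_quality(name):
    """Quality factor parsed the same way as DIV2KJPEG.__init__."""
    return int(name.replace('DIV2K-Q', ''))


for d in 'Set5+Set14+DIV2K'.split('+'):
    print(d, resolve_dataset(d, train=False))
```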


================================================
FILE: src/data/benchmark.py
================================================
import os

from data import common
from data import srdata

import numpy as np

import torch
import torch.utils.data as data
import glob
import pdb

class Benchmark(srdata.SRData):
    def __init__(self, args, name='', train=True, benchmark=True):
        super(Benchmark, self).__init__(
            args, name=name, train=train, benchmark=True)

    def _scan(self):
        # HR files carry an "HR" tag in their name (e.g. baby_HR.png) and LR
        # files an "LR" tag (e.g. baby_LR.png). Both lists are sorted so that
        # HR/LR pairs line up as long as the base names match.
        list_hr = []
        for entry in os.scandir(self.dir_hr):
            filename = os.path.splitext(entry.name)[0]
            if "HR" in filename:
                list_hr.append(os.path.join(self.dir_hr, filename + self.ext))

        # One LR list per scale, read from <dataset>/LR/X<scale>
        list_lr = [[] for _ in self.scale]
        for si, s in enumerate(self.scale):
            dir_lr_s = os.path.join(self.dir_lr, 'X{}'.format(s))
            for entry in os.scandir(dir_lr_s):
                filename = os.path.splitext(entry.name)[0]
                if "LR" in filename:
                    list_lr[si].append(os.path.join(dir_lr_s, filename + self.ext))

        list_hr.sort()
        for l in list_lr:
            l.sort()

        return list_hr, list_lr

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, self.name)
        self.all_files = glob.glob(os.path.join(self.apath, 'HR', "*.png"))
        self.dir_hr = os.path.join(self.apath, 'HR')
        # Root of the per-scale LR folders: LR/X2, LR/X3, LR/X4, ...
        self.dir_lr = os.path.join(self.apath, 'LR')
        self.ext = '.png'
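`Benchmark._scan` sorts the HR and LR lists independently, so a single missing or extra file silently shifts every later pair. A minimal sketch of a stricter pairing, assuming HR and LR names differ only by their `HR`/`LR` tag (e.g. `baby_HR.png` and `baby_LR.png`); `pair_by_stem` is a hypothetical helper, not part of the repository.

```python
import os


def pair_by_stem(hr_files, lr_files, hr_tag='HR', lr_tag='LR'):
    """Pair HR/LR paths whose base names match once the tag is removed.

    Returns (pairs, missing): matched (hr, lr) tuples and HR paths with no LR partner.
    """
    def stem(path, tag):
        name = os.path.splitext(os.path.basename(path))[0]
        return name.replace(tag, '')  # 'baby_HR' and 'baby_LR' both become 'baby_'

    lr_by_stem = {stem(f, lr_tag): f for f in lr_files}
    pairs, missing = [], []
    for hr in sorted(hr_files):
        key = stem(hr, hr_tag)
        if key in lr_by_stem:
            pairs.append((hr, lr_by_stem[key]))
        else:
            missing.append(hr)
    return pairs, missing


pairs, missing = pair_by_stem(
    ['Set5/HR/baby_HR.png', 'Set5/HR/bird_HR.png'],
    ['Set5/LR/X4/baby_LR.png'],
)
print(pairs, missing)
```

Reporting `missing` explicitly makes an incomplete LR folder visible instead of producing misaligned PSNR numbers.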

================================================
FILE: src/data/common.py
================================================
import random

import numpy as np
import skimage.color as sc

import torch

def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):
    ih, iw = args[0].shape[:2]

    if not input_large:
        p = scale if multi else 1
        tp = p * patch_size
        ip = tp // scale
    else:
        tp = patch_size
        ip = patch_size

    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)

    if not input_large:
        tx, ty = scale * ix, scale * iy
    else:
        tx, ty = ix, iy

    ret = [
        args[0][iy:iy + ip, ix:ix + ip, :],
        *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]
    ]

    return ret

def set_channel(*args, n_channels=3):
    def _set_channel(img):
        if img.ndim == 2:
            img = np.expand_dims(img, axis=2)

        c = img.shape[2]
        if n_channels == 1 and c == 3:
            img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
        elif n_channels == 3 and c == 1:
            img = np.concatenate([img] * n_channels, 2)

        return img

    return [_set_channel(a) for a in args]

def np2Tensor(*args, rgb_range=255):
    def _np2Tensor(img):
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        tensor = torch.from_numpy(np_transpose).float()
        tensor.mul_(rgb_range / 255)

        return tensor

    return [_np2Tensor(a) for a in args]

def augment(*args, hflip=True, rot=True):
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip: img = img[:, ::-1, :]
        if vflip: img = img[::-1, :, :]
        if rot90: img = img.transpose(1, 0, 2)
        
        return img

    return [_augment(a) for a in args]
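The sketch below restates `get_patch` (for `multi=False`, `input_large=False`) and `augment` in self-contained form and checks the property they are meant to preserve: an LR patch and its HR patch stay spatially aligned, before and after flips and transposition. Building the HR image by nearest-neighbour upsampling is an assumption made only so alignment can be tested exactly; `get_aligned_patch` and `upsample_nearest` are hypothetical names.

```python
import random

import numpy as np


def get_aligned_patch(lr, hr, patch_size=96, scale=2):
    """Simplified get_patch: patch_size is the HR size, the LR patch is patch_size // scale."""
    ih, iw = lr.shape[:2]
    ip = patch_size // scale
    ix = random.randrange(0, iw - ip + 1)
    iy = random.randrange(0, ih - ip + 1)
    tx, ty = scale * ix, scale * iy  # HR corner that matches the LR corner
    return (lr[iy:iy + ip, ix:ix + ip, :],
            hr[ty:ty + patch_size, tx:tx + patch_size, :])


def augment(*imgs, hflip=True, rot=True):
    """Same random flips/transpose as common.augment, applied identically to every image."""
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip: img = img[:, ::-1, :]
        if vflip: img = img[::-1, :, :]
        if rot90: img = img.transpose(1, 0, 2)
        return img

    return [_augment(a) for a in imgs]


def upsample_nearest(img, scale):
    return np.repeat(np.repeat(img, scale, axis=0), scale, axis=1)


random.seed(0)
rng = np.random.default_rng(0)
lr = rng.integers(0, 256, size=(60, 80, 3)).astype(np.uint8)
hr = upsample_nearest(lr, 2)
for _ in range(20):
    lr_p, hr_p = augment(*get_aligned_patch(lr, hr, patch_size=96, scale=2))
    assert lr_p.shape == (48, 48, 3) and hr_p.shape == (96, 96, 3)
    assert np.array_equal(upsample_nearest(lr_p, 2), hr_p)  # still aligned
```

Because the random choices are drawn once per call and reused for every image, LR and HR always receive the same flip and transpose.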



================================================
FILE: src/data/demo.py
================================================
import os

from data import common

import numpy as np
import imageio

import torch
import torch.utils.data as data

class Demo(data.Dataset):
    def __init__(self, args, name='Demo', train=False, benchmark=False):
        self.args = args
        self.name = name
        self.scale = args.scale
        self.idx_scale = 0
        self.train = False
        self.benchmark = benchmark

        self.filelist = []
        for f in os.listdir(args.dir_demo):
            if f.find('.png') >= 0 or f.find('.jp') >= 0:
                self.filelist.append(os.path.join(args.dir_demo, f))
        self.filelist.sort()

    def __getitem__(self, idx):
        filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]
        lr = imageio.imread(self.filelist[idx])
        lr, = common.set_channel(lr, n_channels=self.args.n_colors)
        lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)

        return lr_t, -1, filename

    def __len__(self):
        return len(self.filelist)

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale



================================================
FILE: src/data/div2k.py
================================================
import os
from data import srdata

class DIV2K(srdata.SRData):
    def __init__(self, args, name='DIV2K', train=True, benchmark=False):
        data_range = [r.split('-') for r in args.data_range.split('/')]
        if train:
            data_range = data_range[0]
        else:
            if args.test_only and len(data_range) == 1:
                data_range = data_range[0]
            else:
                data_range = data_range[1]

        self.begin, self.end = list(map(lambda x: int(x), data_range))
        super(DIV2K, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )

    def _scan(self):
        names_hr, names_lr = super(DIV2K, self)._scan()
        names_hr = names_hr[self.begin - 1:self.end]
        names_lr = [n[self.begin - 1:self.end] for n in names_lr]

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        super(DIV2K, self)._set_filesystem(dir_data)
        self.apath = dir_data
        self.dir_hr = os.path.join(self.apath, 'TrainHR')
        self.dir_lr = os.path.join(self.apath, 'TrainLR')
        #self.dir_lr = os.path.join(self.apath, 'dataset/DIV2K_train_HR')
        if self.input_large: self.dir_lr += 'L'
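How `--data_range`, the HR/LR naming convention, and the training repeat factor fit together can be sketched without the dataset itself. The helpers below mirror `DIV2K.__init__`, `DIV2K._scan`, `SRData._scan`, and the `repeat` computation in `SRData.__init__`; their names are hypothetical, `lr_path` assumes HR files named like `0001_HR.png`, and a batch size of 16 with `test_every` of 1000 are illustrative values, not settings confirmed by this code.

```python
import os


def parse_data_range(data_range, train, test_only=False):
    """'1-800/801-810' -> (1, 800) for training, (801, 810) for testing."""
    ranges = [r.split('-') for r in data_range.split('/')]
    if train:
        chosen = ranges[0]
    elif test_only and len(ranges) == 1:
        chosen = ranges[0]
    else:
        chosen = ranges[1]
    begin, end = (int(x) for x in chosen)
    return begin, end


def select(names, begin, end):
    """1-based inclusive range -> Python slice, as in DIV2K._scan."""
    return names[begin - 1:end]


def lr_path(hr_path, dir_lr, scale, ext='.png'):
    """'<...>/0001_HR.png' -> '<dir_lr>/X<scale>/0001_LR.png'."""
    stem = os.path.splitext(os.path.basename(hr_path))[0].rsplit('_', 1)[0]
    return os.path.join(dir_lr, 'X{}'.format(scale), stem + '_LR' + ext)


def repeat_factor(batch_size, test_every, n_datasets, n_images):
    """How often each image is revisited so one 'epoch' yields batch_size * test_every patches."""
    total = n_datasets * n_images
    return 0 if total == 0 else max((batch_size * test_every) // total, 1)


print(parse_data_range('1-800/801-810', train=True))
print(repeat_factor(16, 1000, 1, 800))
```

With 800 training images and 16 * 1000 = 16000 patches per epoch, each image is repeated 20 times, so `__len__` returns 16000.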



================================================
FILE: src/data/div2kjpeg.py
================================================
import os
from data import srdata
from data import div2k

class DIV2KJPEG(div2k.DIV2K):
    def __init__(self, args, name='', train=True, benchmark=False):
        self.q_factor = int(name.replace('DIV2K-Q', ''))
        super(DIV2KJPEG, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, 'DIV2K')
        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')
        self.dir_lr = os.path.join(
            self.apath, 'DIV2K_Q{}'.format(self.q_factor)
        )
        if self.input_large: self.dir_lr += 'L'
        self.ext = ('.png', '.jpg')



================================================
FILE: src/data/sr291.py
================================================
from data import srdata

class SR291(srdata.SRData):
    def __init__(self, args, name='SR291', train=True, benchmark=False):
        super(SR291, self).__init__(
            args, name=name, train=train, benchmark=benchmark
        )



================================================
FILE: src/data/srdata.py
================================================
import os
import glob
import random
import pickle

from data import common

import numpy as np
import imageio
import torch
import torch.utils.data as data
import pdb

class SRData(data.Dataset):
    def __init__(self, args, name='', train=True, benchmark=False):
        self.args = args
        self.name = name
        self.train = train
        self.split = 'train' if train else 'test'
        self.do_eval = True
        self.benchmark = benchmark
        self.input_large = (args.model == 'VDSR')
        self.scale = args.scale
        self.idx_scale = 0
        
        self._set_filesystem(args.dir_data)
        if args.ext.find('img') < 0:
            path_bin = os.path.join(self.apath, 'bin')
            os.makedirs(path_bin, exist_ok=True)

        list_hr, list_lr = self._scan()
        if args.ext.find('img') >= 0 or benchmark:
            self.images_hr, self.images_lr = list_hr, list_lr
        elif args.ext.find('sep') >= 0:
            os.makedirs(
                self.dir_hr.replace(self.apath, path_bin),
                exist_ok=True
            )
            for s in self.scale:
                os.makedirs(
                    os.path.join(
                        self.dir_lr.replace(self.apath, path_bin),
                        'X{}'.format(s)
                    ),
                    exist_ok=True
                )
            
            self.images_hr, self.images_lr = [], [[] for _ in self.scale]
            for h in list_hr:
                b = h.replace(self.apath, path_bin)
                b = b.replace(self.ext[0], '.pt')
                self.images_hr.append(b)
                self._check_and_load(args.ext, h, b, verbose=True) 
            for i, ll in enumerate(list_lr):
                for l in ll:
                    b = l.replace(self.apath, path_bin)
                    b = b.replace(self.ext[1], '.pt')
                    self.images_lr[i].append(b)
                    self._check_and_load(args.ext, l, b, verbose=True) 
        if train:
            n_patches = args.batch_size * args.test_every
            n_images = len(args.data_train) * len(self.images_hr)
            if n_images == 0:
                self.repeat = 0
            else:
                self.repeat = max(n_patches // n_images, 1)

    # The functions below are used to prepare images
    def _scan(self):
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))
        )
        names_lr = [[] for _ in self.scale]
        for f in names_hr:
            # e.g. '0001_HR' -> '0001'; rsplit tolerates extra underscores in the base name
            filename = os.path.splitext(os.path.basename(f))[0].rsplit('_', 1)[0]
            for si, s in enumerate(self.scale):
                names_lr[si].append(os.path.join(
                    self.dir_lr, 'X{}/{}{}{}'.format(
                        s, filename, '_LR', self.ext[1]
                    )
                ))

        return names_hr, names_lr

    def _set_filesystem(self, dir_data):
        self.apath = os.path.join(dir_data, self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        if self.input_large: self.dir_lr += 'L'
        self.ext = ('.png', '.png')

    def _check_and_load(self, ext, img, f, verbose=True):
        if not os.path.isfile(f) or ext.find('reset') >= 0:
            if verbose:
                print('Making a binary: {}'.format(f))
            with open(f, 'wb') as _f:
                pickle.dump(imageio.imread(img), _f)

    def __getitem__(self, idx):
        lr, hr, filename = self._load_file(idx)
        pair = self.get_patch(lr, hr)
        pair = common.set_channel(*pair, n_channels=self.args.n_colors)
        pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)

        return pair_t[0], pair_t[1], filename

    def __len__(self):
        if self.train:
            return len(self.images_hr) * self.repeat
        else:
            return len(self.images_hr)

    def _get_index(self, idx):
        if self.train:
            return idx % len(self.images_hr)
        else:
            return idx

    def _load_file(self, idx):
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]
        f_lr = self.images_lr[self.idx_scale][idx]

        filename, _ = os.path.splitext(os.path.basename(f_hr))
        if self.args.ext == 'img' or self.benchmark:
            hr = imageio.imread(f_hr)
            lr = imageio.imread(f_lr)
        elif self.args.ext.find('sep') >= 0:
            with open(f_hr, 'rb') as _f:
                hr = pickle.load(_f)
            with open(f_lr, 'rb') as _f:
                lr = pickle.load(_f)

        return lr, hr, filename

    def get_patch(self, lr, hr):
        scale = self.scale[self.idx_scale]
        if self.train:
            lr, hr = common.get_patch(
                lr, hr,
                patch_size=self.args.patch_size,
                scale=scale,
                multi=(len(self.scale) > 1),
                input_large=self.input_large
            )
            if not self.args.no_augment: lr, hr = common.augment(lr, hr)
        else:
            ih, iw = lr.shape[:2]
            hr = hr[0:ih * scale, 0:iw * scale]

        return lr, hr

    def set_scale(self, idx_scale):
        if not self.input_large:
            self.idx_scale = idx_scale
        else:
            self.idx_scale = random.randint(0, len(self.scale) - 1)
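
A minimal sketch of two pieces of bookkeeping in `SRData` above, pulled out as standalone functions for illustration: the `repeat` factor that stretches one training epoch to roughly `batch_size * test_every` patches, and the HR-to-LR path mapping done by `_scan`. `compute_repeat`, `lr_path_for` and the example paths are hypothetical, not part of the repository.

```python
import os


def compute_repeat(batch_size, test_every, n_train_sets, n_images):
    """Hypothetical helper mirroring SRData's `repeat` computation."""
    n_patches = batch_size * test_every        # patches wanted per epoch
    n_total = n_train_sets * n_images          # images seen in one pass
    if n_total == 0:
        return 0
    return max(n_patches // n_total, 1)        # always at least one pass


def lr_path_for(hr_path, dir_lr, scale, ext='.png'):
    """Hypothetical helper mirroring _scan: '<name>_<tag>.png' -> 'X<s>/<name>_LR.png'."""
    stem = os.path.splitext(os.path.basename(hr_path))[0]
    name, _ = stem.split('_')                  # requires exactly one underscore
    return os.path.join(dir_lr, 'X{}/{}{}{}'.format(scale, name, '_LR', ext))


# With 800 training images, batch_size=16 and test_every=1000, each image is
# revisited 16000 // 800 = 20 times per epoch, so __len__ returns 800 * 20.
print(compute_repeat(16, 1000, 1, 800))
print(lr_path_for('/data/Set5/HR/baby_HR.png', '/data/Set5/LR_bicubic', 4))
```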



================================================
FILE: src/data/video.py
================================================
import os

from data import common

import cv2
import numpy as np
import imageio

import torch
import torch.utils.data as data

class Video(data.Dataset):
    def __init__(self, args, name='Video', train=False, benchmark=False):
        self.args = args
        self.name = name
        self.scale = args.scale
        self.idx_scale = 0
        self.train = False
        self.do_eval = False
        self.benchmark = benchmark

        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))
        self.vidcap = cv2.VideoCapture(args.dir_demo)
        self.n_frames = 0
        self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __getitem__(self, idx):
        success, lr = self.vidcap.read()
        if success:
            self.n_frames += 1
            lr, = common.set_channel(lr, n_channels=self.args.n_colors)
            lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)

            return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)
        else:
            self.vidcap.release()
            return None

    def __len__(self):
        return self.total_frames

    def set_scale(self, idx_scale):
        self.idx_scale = idx_scale



================================================
FILE: src/dataloader.py
================================================
import sys
import threading
import random

import torch
import torch.multiprocessing as multiprocessing
from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler
from torch.utils.data import RandomSampler
from torch.utils.data import BatchSampler
from torch.utils.data import _utils
from torch.utils.data.dataloader import _DataLoaderIter

from torch.utils.data._utils import collate
from torch.utils.data._utils import signal_handling
from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL
from torch.utils.data._utils import ExceptionWrapper
from torch.utils.data._utils import IS_WINDOWS
from torch.utils.data._utils.worker import ManagerWatchdog

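# standard-library queue (torch._six was a Python 2/3 compatibility shim
# that has been removed from recent PyTorch releases)
import queue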

def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):
    try:
        collate._use_shared_memory = True
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        random.seed(seed)
        torch.manual_seed(seed)

        data_queue.cancel_join_thread()

        if init_fn is not None:
            init_fn(worker_id)

        watchdog = ManagerWatchdog()

        while watchdog.is_alive():
            try:
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue

            if r is None:
                assert done_event.is_set()
                return
            elif done_event.is_set():
                continue

            idx, batch_indices = r
            try:
                idx_scale = 0
                if len(scale) > 1 and dataset.train:
                    idx_scale = random.randrange(0, len(scale))
                    dataset.set_scale(idx_scale)

                samples = collate_fn([dataset[i] for i in batch_indices])
                samples.append(idx_scale)
            except Exception:
                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
            else:
                data_queue.put((idx, samples))
                del samples

    except KeyboardInterrupt:
        pass
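
`_ms_loop` draws one random scale index per batch inside each worker, so every sample in a batch shares the same scale. Because the surrounding code depends on private DataLoader internals (`_DataLoaderIter`, `torch.utils.data._utils`) that are not public API and have changed between PyTorch releases, one possible alternative, sketched here under that assumption, is to make the scale choice in a custom batch sampler and hand `(index, idx_scale)` pairs to the dataset. `MultiScaleBatchSampler` is a hypothetical class, and a dataset using it would need a `__getitem__` that unpacks the pair and calls `set_scale` itself.

```python
import random


class MultiScaleBatchSampler:
    """Wrap a batch sampler so that each batch carries one random scale index.

    Yields lists of (index, idx_scale) pairs. It is meant to be passed to
    torch.utils.data.DataLoader as `batch_sampler=`, which is documented to
    accept a Sampler or any iterable of index batches.
    """

    def __init__(self, batch_sampler, n_scales, seed=None):
        self.batch_sampler = batch_sampler
        self.n_scales = n_scales
        self.rng = random.Random(seed)

    def __iter__(self):
        for batch in self.batch_sampler:
            # same rule as _ms_loop: only randomize when there are several scales
            idx_scale = self.rng.randrange(self.n_scales) if self.n_scales > 1 else 0
            yield [(i, idx_scale) for i in batch]

    def __len__(self):
        return len(self.batch_sampler)


batches = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
for batch in MultiScaleBatchSampler(batches, n_scales=3, seed=0):
    print(batch)
```

Unlike `_ms_loop`, this sketch does not check `dataset.train`; a test-time loader would simply use `n_scales=1`.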

class _MSDataLoaderIter(_DataLoaderIter):

    def __init__(self, loader):
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout

        self.sample_iter = iter(self.batch_sampler)

        base_seed = torch.LongTensor(1).random_().item()

        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            self.done_event = multiprocessing.Event()

            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue()
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        index_queue,
                        self.worker_result_queue,
                        self.done_event,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                w.daemon = True
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

            if self.pin_memory:
                self.data_queue = queue.Queue()
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(
                        self.worker_result_queue,
                        self.data_queue,
                        torch.cuda.current_device(),
                        self.done_event
                    )
                )
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                self.pin_memory_thread = pin_memory_thread
            else:
                self.data_queue = self.worker_result_queue

            _utils.signal_handling._set_worker_pids(
                id(self), tuple(w.pid for w in self.workers)
            )
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True

            for _ in range(2 * self.num_workers):
                self._put_indices()


class MSDataLoader(DataLoader):

    def __init__(self, cfg, *args, **kwargs):
        super(MSDataLoader, self).__init__(
            *args, **kwargs, num_workers=cfg.n_threads
        )
        self.scale = cfg.scale

    def __iter__(self):
        return _MSDataLoaderIter(self)
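
`_MSDataLoaderIter` seeds worker `i` with `base_seed + i`, so the Python `random` module, which `_ms_loop` uses to pick the scale, produces a different stream in each worker. With the stock `DataLoader`, PyTorch's reproducibility notes describe a comparable pattern: each worker's torch seed is already `base_seed + worker_id`, and a `worker_init_fn` can derive the Python and NumPy seeds from `torch.initial_seed()`. A sketch of that pattern; `seed_worker` is a hypothetical name.

```python
import random

import numpy as np
import torch


def seed_worker(worker_id):
    """worker_init_fn: derive Python/NumPy seeds from the worker's torch seed.

    Inside a DataLoader worker, torch.initial_seed() already equals
    base_seed + worker_id, so each worker gets a distinct stream.
    """
    seed = torch.initial_seed() % 2**32
    random.seed(seed)
    np.random.seed(seed)


# Hypothetical usage with the stock DataLoader:
#   g = torch.Generator()
#   g.manual_seed(0)
#   DataLoader(dataset, batch_size=16, num_workers=4,
#              worker_init_fn=seed_worker, generator=g)
```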



================================================
FILE: src/demo.sh
================================================
# MatrixModel (x4), initialized from an earlier MatrixModel run
#python3 main.py --model MatrixModel --scale 4 --patch_size 192 --save MatrixModelG7_x4 --reset --pre_train /media/zrh/cc9cb710-2fc7-4382-81ff-649502a83b92/EDSR-PyTorch-master/experiment/MatrixModelG6_x4/model/model_best.pt
# EDSR baseline model (x2) + JPEG augmentation
#python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75

# EDSR baseline model (x3) - from EDSR baseline model (x2)
#python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]

# EDSR baseline model (x4) - from EDSR baseline model (x2)
#python main.py --model EDSR --scale 4 --save edsr_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]

# EDSR in the paper (x2)
#python main.py --model EDSR --scale 2 --save edsr_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset

# EDSR in the paper (x3) - from EDSR (x2)
#python main.py --model EDSR --scale 3 --save edsr_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR model dir]

# EDSR in the paper (x4) - from EDSR (x2)
#python main.py --model EDSR --scale 4 --save edsr_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]

# MDSR baseline model
#python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models

# MDSR in the paper
#python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models

# Standard benchmarks (Ex. EDSR_baseline_x4)
#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble

#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble

# Test your own images
#python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results

# Advanced - Test with JPEG images 
#python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results

# Advanced - Training with adversarial loss
#python main.py --template GAN --scale 4 --save edsr_gan --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download

# RDN BI model (x2)
#python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset
# RDN BI model (x3)
#python3.6 main.py --scale 3 --save RDN_D16C8G64_BIx3 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 96 --reset
# RDN BI model (x4)
#python main.py --scale 4 --save RDN9_D16C8G64_BIx4 --model RDN --epochs 400 --batch_size 16 --patch_size 128 --reset #--pre_train /home/visionx/wwl/project/EDSR-PyTorch-master/experiment/RDN7_D16C8G64_BIx4/model/model_best.pt

# RCAN_BIX2_G10R20P48, input=48x48, output=96x96
# pretrained model can be downloaded from https://www.dropbox.com/s/mjbcqkd4nwhr6nu/models_ECCV2018RCAN.zip?dl=0
#python main.py --template RCAN --save RCAN_BIX2_G10R20P48 --scale 2 --reset --save_results --patch_size 96
# RCAN_BIX3_G10R20P48, input=48x48, output=144x144
#python main.py --template RCAN --save RCAN_BIX3_G10R20P48 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt
# Test RCAN (x8) on the standard benchmarks
#python main.py --template RCAN2 --data_test Set5+Set14+B100+Urban100+Manga109 --data_range 801-900 --scale 8 --pre_train ../experiment/RCAN81_BIX8_G10R20P48/model/model_best.pt --test_only --save RCAN_test --save_results
# RCAN_BIX4_G10R20P48, input=48x48, output=192x192
#python main.py --template RCAN2 --save RCAN3_BIX4_G10R20P48 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt
# RCAN_BIX8_G10R20P48, input=48x48, output=384x384
#python main.py --template RCAN2 --save RCAN81_BIX8_G10R20P48 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX8.pt

# HAN BI model (x2)
#python main.py --template HAN --save HANx2 --scale 2 --reset --save_results --patch_size 96 --pre_train ../experiment/model/RCAN_BIX2.pt
# HAN BI model (x3)
#python main.py --template HAN --save HANx3 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt
# HAN BI model (x4)
#python main.py --template HAN --save HANx4 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt
# HAN BI model (x8)
#python main.py --template HAN --save HANx8 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX2.pt
# Test HAN
#python main.py --template HAN --data_test Set5+Set14+B100+Urban100+Manga109 --data_range 801-900 --scale 2 --pre_train ../experiment/HAN/HAN_BIX2.pt --test_only --save HANx2_test --save_results
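
The HAN (and RCAN) training commands above fix the LR input patch at 48x48 and pass the HR patch size as `--patch_size`, i.e. 48 times the scale: 96, 144, 192 and 384 for x2, x3, x4 and x8. A small sketch that rebuilds the HAN training command for a given scale; `han_train_args` is a hypothetical helper and only reuses flags that appear in this script.

```python
def han_train_args(scale, lr_patch=48, pre_train='../experiment/model/RCAN_BIX2.pt'):
    """Argument list for the 'HAN BI model (xN)' training commands above."""
    return [
        'python', 'main.py', '--template', 'HAN',
        '--save', 'HANx{}'.format(scale),
        '--scale', str(scale),
        '--reset', '--save_results',
        '--patch_size', str(lr_patch * scale),  # HR patch = LR patch x scale
        '--pre_train', pre_train,
    ]


for scale in (2, 3, 4, 8):
    print(' '.join(han_train_args(scale)))

# To launch one run from the repository root (not executed here):
#   import subprocess
#   subprocess.run(han_train_args(4), cwd='src', check=True)
```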


================================================
FILE: src/loss/__init__.py
================================================
import os
from importlib import import_module

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class Loss(nn.modules.loss._Loss):
    def __init__(self, args, ckp):
        super(Loss, self).__init__()
        print('Preparing loss function:')

        self.n_GPUs = args.n_GPUs
        self.loss = []
        self.loss_module = nn.ModuleList()
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':
                loss_function = nn.MSELoss()
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif loss_type.find('VGG') >= 0:
                module = import_module('loss.vgg')
                loss_function = getattr(module, 'VGG')(
                    loss_type[3:],
                    rgb_range=args.rgb_range
                )
            elif loss_type.find('GAN') >= 0:
                module = import_module('loss.adversarial')
                loss_function = getattr(module, 'Adversarial')(
                    args,
                    loss_type
                )
            else:
                raise ValueError('Unknown loss type: {}'.format(loss_type))

            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function}
            )
            if loss_type.find('GAN') >= 0:
                self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})

        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})

        for l in self.loss:
            if l['function'] is not None:
                print('{:.3f} * {}'.format(l['weight'], l['type']))
                self.loss_module.append(l['function'])

        self.log = torch.Tensor()

        device = torch.device('cpu' if args.cpu else 'cuda')
        self.loss_module.to(device)
        if args.precision == 'half': self.loss_module.half()
        if not args.cpu and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(
                self.loss_module, range(args.n_GPUs)
            )

        if args.load != '': self.load(ckp.dir, cpu=args.cpu)

    def forward(self, sr, hr):
        losses = []
        for i, l in enumerate(self.loss):
            if l['function'] is not None:
                loss = l['function'](sr, hr)
                effective_loss = l['weight'] * loss
                losses.append(effective_loss)
                self.log[-1, i] += effective_loss.item()
            elif l['type'] == 'DIS':
                self.log[-1, i] += self.loss[i - 1]['function'].loss

        loss_sum = sum(losses)
        if len(self.loss) > 1:
            self.log[-1, -1] += loss_sum.item()

        return loss_sum

    def step(self):
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                l.scheduler.step()

    def start_log(self):
        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))

    def end_log(self, n_batches):
        self.log[-1].div_(n_batches)

    def display_loss(self, batch):
        n_samples = batch + 1
        log = []
        for l, c in zip(self.loss, self.log[-1]):
            log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))

        return ''.join(log)

    def plot_loss(self, apath, epoch):
        axis = np.linspace(1, epoch, epoch)
        for i, l in enumerate(self.loss):
            label = '{} Loss'.format(l['type'])
            fig = plt.figure()
            plt.title(label)
            plt.plot(axis, self.log[:, i].numpy(), label=label)
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))
            plt.close(fig)

    def get_loss_module(self):
        if self.n_GPUs == 1:
            return self.loss_module
        else:
            return self.loss_module.module

    def save(self, apath):
        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))

    def load(self, apath, cpu=False):
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}

        self.load_state_dict(torch.load(
            os.path.join(apath, 'loss.pt'),
            **kwargs
        ))
        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                for _ in range(len(self.log)): l.scheduler.step()
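
`Loss.__init__` parses `--loss` strings such as `1*L1` or `5*VGG54+0.15*GAN` (the latter appears in `demo.sh`): terms are separated by `+`, each term is `weight*type`, every GAN-type term is followed by a `DIS` entry that logs the discriminator loss, and a `Total` entry is appended when there is more than one entry. A dependency-free sketch of just that bookkeeping; `parse_loss_spec` is a hypothetical name and builds no loss modules.

```python
def parse_loss_spec(spec):
    """Return the (type, weight) log entries Loss.__init__ would create for `spec`."""
    entries = []
    for term in spec.split('+'):
        weight, loss_type = term.split('*')
        entries.append((loss_type, float(weight)))
        if 'GAN' in loss_type:
            # the discriminator's own loss is logged right after the GAN term
            entries.append(('DIS', 1.0))
    if len(entries) > 1:
        entries.append(('Total', 0.0))
    return entries


print(parse_loss_spec('5*VGG54+0.15*GAN'))
# [('VGG54', 5.0), ('GAN', 0.15), ('DIS', 1.0), ('Total', 0.0)]
```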



================================================
FILE: src/loss/adversarial.py
================================================
import utility
from types import SimpleNamespace

from model import common
from loss import discriminator

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Adversarial(nn.Module):
    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k
        self.dis = discriminator.Discriminator(args)
        if gan_type == 'WGAN_GP':
            # see https://arxiv.org/pdf/1704.00028.pdf pp.4
            optim_dict = {
                'optimizer': 'ADAM',
                'betas': (0, 0.9),
                'epsilon': 1e-8,
                'lr': 1e-5,
                'weight_decay': args.weight_decay,
                'decay': args.decay,
                'gamma': args.gamma
            }
            optim_args = SimpleNamespace(**optim_dict)
        else:
            optim_args = args

        self.optimizer = utility.make_optimizer(optim_args, self.dis)

    def forward(self, fake, real):
        # updating discriminator...
        self.loss = 0
        fake_detach = fake.detach()     # do not backpropagate through G
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            # d: B x 1 tensor
            d_fake = self.dis(fake_detach)
            d_real = self.dis(real)
            retain_graph = False
            if self.gan_type == 'GAN':
                loss_d = self.bce(d_real, d_fake)
            elif self.gan_type.find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
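                    # one interpolation coefficient per sample, broadcast over C, H, W
                    epsilon = torch.rand(
                        fake.size(0), 1, 1, 1,
                        device=fake.device, dtype=fake.dtype
                    )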
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.dis(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty
            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
            elif self.gan_type == 'RGAN':
                better_real = d_real - d_fake.mean(dim=0, keepdim=True)
                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)
                loss_d = self.bce(better_real, better_fake)
                retain_graph = True

            # Discriminator update
            self.loss += loss_d.item()
            loss_d.backward(retain_graph=retain_graph)
            self.optimizer.step()

            if self.gan_type == 'WGAN':
                for p in self.dis.parameters():
                    p.data.clamp_(-1, 1)

        self.loss /= self.gan_k

        # updating generator...
        d_fake_bp = self.dis(fake)      # for backpropagation, use fake as it is
        if self.gan_type == 'GAN':
            label_real = torch.ones_like(d_fake_bp)
            loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_bp.mean()
        elif self.gan_type == 'RGAN':
            better_real = d_real - d_fake_bp.mean(dim=0, keepdim=True)
            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)
            loss_g = self.bce(better_fake, better_real)

        # Generator loss
        return loss_g
    
    def state_dict(self, *args, **kwargs):
        state_discriminator = self.dis.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()

        return dict(**state_discriminator, **state_optimizer)

    def bce(self, real, fake):
        label_real = torch.ones_like(real)
        label_fake = torch.zeros_like(fake)
        bce_real = F.binary_cross_entropy_with_logits(real, label_real)
        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)
        bce_loss = bce_real + bce_fake
        return bce_loss
               
# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
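
The WGAN-GP branch above follows the paper linked in `__init__`: it evaluates the critic at random interpolates between real and generated images and penalizes deviations of the input-gradient norm from 1, with weight 10. The interpolation coefficient has to be drawn once per sample, with shape `(B, 1, 1, 1)`, so that it broadcasts over channels and pixels. A standalone sketch under those assumptions; `gradient_penalty` and the toy critic are hypothetical, not repository code.

```python
import torch
import torch.nn as nn


def gradient_penalty(critic, real, fake, weight=10.0):
    """WGAN-GP penalty: weight * mean((||grad_x critic(x_hat)||_2 - 1)^2)."""
    # one interpolation coefficient per sample, broadcast over C, H, W
    epsilon = torch.rand(real.size(0), 1, 1, 1, device=real.device, dtype=real.dtype)
    hat = (fake.detach() * (1 - epsilon) + real * epsilon).requires_grad_(True)
    d_hat = critic(hat)
    gradients = torch.autograd.grad(
        outputs=d_hat.sum(), inputs=hat, create_graph=True
    )[0]
    gradient_norm = gradients.reshape(gradients.size(0), -1).norm(2, dim=1)
    return weight * (gradient_norm - 1).pow(2).mean()


# toy critic on 3x8x8 inputs, standing in for loss.discriminator.Discriminator
critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
real = torch.rand(4, 3, 8, 8)
fake = torch.rand(4, 3, 8, 8)
gp = gradient_penalty(critic, real, fake)
gp.backward()  # create_graph=True lets the penalty reach the critic's weights
print(gp.item())
```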


================================================
FILE: src/loss/discriminator.py
================================================
from model import common

import torch.nn as nn

class Discriminator(nn.Module):
    '''
        Output is an unnormalized logit per image (no sigmoid is applied).
    '''
    def __init__(self, args):
        super(Discriminator, self).__init__()

        in_channels = args.n_colors
        out_channels = 64
        depth = 7

        def _block(_in_channels, _out_channels, stride=1):
            return nn.Sequential(
                nn.Conv2d(
                    _in_channels,
                    _out_channels,
                    3,
                    padding=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(_out_channels),
                nn.LeakyReLU(negative_slope=0.2, inplace=True)
            )

        m_features = [_block(in_channels, out_channels)]
        for i in range(depth):
            in_channels = out_channels
            if i % 2 == 1:
                stride = 1
                out_channels *= 2
            else:
                stride = 2
            m_features.append(_block(in_channels, out_channels, stride=stride))

        patch_size = args.patch_size // (2**((depth + 1) // 2))
        m_classifier = [
            nn.Linear(out_channels * patch_size**2, 1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Linear(1024, 1)
        ]

        self.features = nn.Sequential(*m_features)
        self.classifier = nn.Sequential(*m_classifier)

    def forward(self, x):
        features = self.features(x)
        output = self.classifier(features.view(features.size(0), -1))

        return output
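
The classifier's first `Linear` layer must match the flattened feature map. With `depth = 7`, the blocks at even `i` (0, 2, 4, 6) use stride 2 and the blocks at odd `i` double the channels, so the map shrinks by 2**4 = 16 and ends with 64 * 2**3 = 512 channels. The sketch below compares the true Conv2d output size with the `patch_size // 2**((depth + 1) // 2)` shortcut used in `__init__`; the helper names are hypothetical. A stride-2 3x3 convolution with padding 1 rounds up, so the two agree only when `patch_size` is a multiple of 16.

```python
def conv_out(n, kernel=3, padding=1, stride=1):
    """Output length of a Conv2d along one spatial axis."""
    return (n + 2 * padding - kernel) // stride + 1


def classifier_in_features(patch_size, depth=7, base_channels=64):
    """Flattened feature size that actually reaches Discriminator.classifier."""
    channels, size = base_channels, conv_out(patch_size)  # first block: stride 1
    for i in range(depth):
        if i % 2 == 1:
            channels *= 2                    # odd i: stride 1, channels doubled
        else:
            size = conv_out(size, stride=2)  # even i: size halved, rounded up
    return channels * size ** 2


def repo_in_features(patch_size, depth=7, base_channels=64):
    """What Discriminator.__init__ assumes for nn.Linear's input size."""
    out_channels = base_channels * 2 ** (depth // 2)
    return out_channels * (patch_size // 2 ** ((depth + 1) // 2)) ** 2


print(classifier_in_features(96), repo_in_features(96))  # 18432 18432
```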



================================================
FILE: src/loss/vgg.py
================================================
from model import common

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class VGG(nn.Module):
    def __init__(self, conv_index, rgb_range=1):
        super(VGG, self).__init__()
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m for m in vgg_features]
        if conv_index.find('22') >= 0:
            self.vgg = nn.Sequential(*modules[:8])
        elif conv_index.find('54') >= 0:
            self.vgg = nn.Sequential(*modules[:35])

        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, sr, hr):
        def _forward(x):
            x = self.sub_mean(x)
            x = self.vgg(x)
            return x
            
        vgg_sr = _forward(sr)
        with torch.no_grad():
            vgg_hr = _forward(hr.detach())

        loss = F.mse_loss(vgg_sr, vgg_hr)

        return loss
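
`VGG22` and `VGG54` are selected by slicing `vgg19().features` at indices 8 and 35. Assuming torchvision's standard VGG-19 layout (configuration 'E': a Conv2d/ReLU pair per channel entry and a max-pool per 'M', no batch norm), the sketch below names each index, which shows that both slices end on a convolution, i.e. the features are taken before the ReLU of conv2_2 and conv5_4. The layer names are a labelling convention, not torchvision API.

```python
# torchvision's VGG-19 ('E') configuration: channel counts, 'M' = max-pool
VGG19_CFG = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
             512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']


def vgg19_feature_names(cfg=VGG19_CFG):
    """Name every module in vgg19().features, e.g. 'conv2_2', 'relu2_2', 'pool2'."""
    names, block, conv = [], 1, 1
    for v in cfg:
        if v == 'M':
            names.append('pool{}'.format(block))
            block, conv = block + 1, 1
        else:
            names.append('conv{}_{}'.format(block, conv))
            names.append('relu{}_{}'.format(block, conv))
            conv += 1
    return names


names = vgg19_feature_names()
# last module kept by modules[:8] and by modules[:35]
print(len(names), names[7], names[34])  # 37 conv2_2 conv5_4
```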


================================================
FILE: src/main.py
================================================
import torch

import utility
import data
import model
import loss
from option import args
from trainer import Trainer

torch.manual_seed(args.seed)
checkpoint = utility.checkpoint(args)

def main():
    global model
    if args.data_test == ['video']:
        from videotester import VideoTester
        model = model.Model(args, checkpoint)
        t = VideoTester(args, model, checkpoint)
        t.test()
    else:
        if checkpoint.ok:
            loader = data.Data(args)
            _model = model.Model(args, checkpoint)
            _loss = loss.Loss(args, checkpoint) if not args.test_only else None
            t = Trainer(args, loader, _model, _loss, checkpoint)
            while not t.terminate():
                t.train()
                t.test()

            checkpoint.done()

if __name__ == '__main__':
    main()


================================================
FILE: src/model/__init__.py
================================================
import os
from importlib import import_module

import torch
import torch.nn as nn
import torch.nn.parallel as P
import torch.utils.model_zoo
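# NOTE: hard-codes the visible GPUs (0 and 1); this only takes effect if CUDA has
# not been initialized yet. Edit or remove it to match your machine.
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'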

class Model(nn.Module):
    def __init__(self, args, ckp):
        super(Model, self).__init__()
        print('Making model...')

        self.scale = args.scale
        self.idx_scale = 0
        self.input_large = (args.model == 'VDSR')
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        self.device = torch.device('cpu' if args.cpu else 'cuda')
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models

        module = import_module('model.' + args.model.lower())
        self.model = module.make_model(args).to(self.device)
        if args.precision == 'half':
            self.model.half()

        self.load(
            ckp.get_path('model'),
            pre_train=args.pre_train,
            resume=args.resume,
            cpu=args.cpu
        )
        print(self.model, file=ckp.log_file)

    def forward(self, x, idx_scale):
        self.idx_scale = idx_scale
        if hasattr(self.model, 'set_scale'):
            self.model.set_scale(idx_scale)

        if self.training:
            if self.n_GPUs > 1:
                return P.data_parallel(self.model, x, range(self.n_GPUs))
            else:
                return self.model(x)
        else:
            if self.chop:
                forward_function = self.forward_chop
            else:
                forward_function = self.model.forward

            if self.self_ensemble:
                return self.forward_x8(x, forward_function=forward_function)
            else:
                return forward_function(x)

    def save(self, apath, epoch, is_best=False):
        save_dirs = [os.path.join(apath, 'model_latest.pt')]

        if is_best:
            save_dirs.append(os.path.join(apath, 'model_best.pt'))
        if self.save_models:
            save_dirs.append(
                os.path.join(apath, 'model_{}.pt'.format(epoch))
            )

        for s in save_dirs:
            torch.save(self.model.state_dict(), s)

    def load(self, apath, pre_train='', resume=-1, cpu=False):
        load_from = None
        kwargs = {}
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}

        if resume == -1:
            load_from = torch.load(
                os.path.join(apath, 'model_latest.pt'),
                **kwargs
            )
        elif resume == 0:
            if pre_train == 'download':
                print('Download the model')
                dir_model = os.path.join('..', 'models')
                os.makedirs(dir_model, exist_ok=True)
                load_from = torch.utils.model_zoo.load_url(
                    self.model.url,
                    model_dir=dir_model,
                    **kwargs
                )
            elif pre_train:
                print('Load the model from {}'.format(pre_train))
                load_from = torch.load(pre_train, **kwargs)
        else:
            load_from = torch.load(
                os.path.join(apath, 'model_{}.pt'.format(resume)),
                **kwargs
            )

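        if load_from:
            # strict=False tolerates missing or unexpected keys, e.g. when
            # --pre_train points to an RCAN checkpoint as in demo.sh
            self.model.load_state_dict(load_from, strict=False)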
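
`Model.load` picks its source from `resume` and `pre_train`: `-1` loads `model_latest.pt` from the experiment's model directory, `0` loads the `--pre_train` file (downloads `self.model.url` when it is `download`, and loads nothing when it is empty), and any other value loads `model_<resume>.pt`. A dependency-free sketch of that decision; `resolve_checkpoint` is a hypothetical helper that returns a description instead of loading anything.

```python
import os


def resolve_checkpoint(apath, pre_train='', resume=-1):
    """Return ('file', path), ('download', None) or (None, None), mirroring Model.load."""
    if resume == -1:
        return 'file', os.path.join(apath, 'model_latest.pt')
    if resume == 0:
        if pre_train == 'download':
            return 'download', None        # fetched from self.model.url into ../models
        if pre_train:
            return 'file', pre_train
        return None, None                  # nothing to load: train from scratch
    return 'file', os.path.join(apath, 'model_{}.pt'.format(resume))
```

One consequence visible in the sketch: `pre_train` is consulted only when `resume` is 0.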

    def forward_chop(self, x, shave=10, min_size=160000):
        scale = self.scale[self.idx_scale]
        n_GPUs = min(self.n_GPUs, 4)
        b, c, h, w = x.size()
        h_half, w_half = h // 2, w // 2
        h_size, w_size = h_half + shave, w_half + shave
        lr_list = [
            x[:, :, 0:h_size, 0:w_size],
            x[:, :, 0:h_size, (w - w_size):w],
            x[:, :, (h - h_size):h, 0:w_size],
            x[:, :, (h - h_size):h, (w - w_size):w]]

        if w_size * h_size < min_size:
            sr_list = []
            for i in range(0, 4, n_GPUs):
                lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
                sr_batch = self.model(lr_batch)
                sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
        else:
            sr_list = [
                self.forward_chop(patch, shave=shave, min_size=min_size)
                for patch in lr_list
            ]

        h, w = scale * h, scale * w
        h_half, w_half = scale * h_half, scale * w_half
        h_size, w_size = scale * h_size, scale * w_size
        shave *= scale

        output = x.new(b, c, h, w)
        output[:, :, 0:h_half, 0:w_half] \
            = sr_list[0][:, :, 0:h_half, 0:w_half]
        output[:, :, 0:h_half, w_half:w] \
            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
        output[:, :, h_half:h, 0:w_half] \
            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
        output[:, :, h_half:h, w_half:w] \
            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

        return output

    def forward_x8(self, *args, forward_function=None):
        def _transform(v, op):
            if self.precision != 'single': v = v.float()

            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            if self.precision == 'half': ret = ret.half()

            return ret

        list_x = []
        for a in args:
            x = [a]
            for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x])

            list_x.append(x)

        list_y = []
        for x in zip(*list_x):
            y = forward_function(*x)
            if not isinstance(y, list): y = [y]
            if not list_y:
                list_y = [[_y] for _y in y]
            else:
                for _list_y, _y in zip(list_y, y): _list_y.append(_y)

        for _list_y in list_y:
            for i in range(len(_list_y)):
                if i > 3:
                    _list_y[i] = _transform(_list_y[i], 't')
                if i % 4 > 1:
                    _list_y[i] = _transform(_list_y[i], 'h')
                if (i % 4) % 2 == 1:
                    _list_y[i] = _transform(_list_y[i], 'v')

        y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y]
        if len(y) == 1: y = y[0]

        return y
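

# Example sketch (not part of the original repository): the index arithmetic
# behind forward_chop and forward_x8 above, re-expressed in plain Python so it
# can be checked without a GPU. Images are modelled as nested lists (H rows of
# W values) instead of NCHW tensors. The helper names chop_windows,
# check_chop_cover and x8_roundtrip are hypothetical.


def chop_windows(h, w, shave):
    """The four overlapping LR crops of forward_chop as (top, bottom, left, right)."""
    h_half, w_half = h // 2, w // 2
    h_size, w_size = h_half + shave, w_half + shave
    return [
        (0, h_size, 0, w_size),
        (0, h_size, w - w_size, w),
        (h - h_size, h, 0, w_size),
        (h - h_size, h, w - w_size, w),
    ]


def check_chop_cover(h, w, shave, scale):
    """Replay the stitching slices and confirm every SR pixel is written once."""
    h_half, w_half = h // 2, w // 2
    h_size, w_size = h_half + shave, w_half + shave
    H, W = scale * h, scale * w
    hh, wh = scale * h_half, scale * w_half
    hs, ws = scale * h_size, scale * w_size
    # (out rows r0:r1, out cols c0:c1, source row start, source col start),
    # copied from the output[...] = sr_list[k][...] assignments.
    quads = [
        (0, hh, 0, wh, 0, 0),
        (0, hh, wh, W, 0, ws - W + wh),
        (hh, H, 0, wh, hs - H + hh, 0),
        (hh, H, wh, W, hs - H + hh, ws - W + wh),
    ]
    hits = [[0] * W for _ in range(H)]
    for (r0, r1, c0, c1, sr, sc), (top, _, left, _) in zip(quads, chop_windows(h, w, shave)):
        # The slice read from the upscaled patch must stay inside that patch...
        assert 0 <= sr and sr + (r1 - r0) <= hs
        assert 0 <= sc and sc + (c1 - c0) <= ws
        # ...and line up with where the patch sits in the full SR image.
        assert sr == r0 - scale * top and sc == c0 - scale * left
        for r in range(r0, r1):
            for c in range(c0, c1):
                hits[r][c] += 1
    return all(v == 1 for row in hits for v in row)


def _flip_w(img):      # the source's 'v' op: reverse the last (width) axis
    return [row[::-1] for row in img]


def _flip_h(img):      # the source's 'h' op: reverse the height axis
    return img[::-1]


def _transpose(img):   # the source's 't' op: swap height and width
    return [list(col) for col in zip(*img)]


_OPS = {'v': _flip_w, 'h': _flip_h, 't': _transpose}


def x8_roundtrip(img, model=lambda a: a):
    """Build the 8 variants in forward_x8's order, run `model`, undo each one."""
    xs = [img]
    for tf in 'vht':
        # Order after all passes: [id, v, h, hv, t, tv, th, thv].
        xs.extend([_OPS[tf](x) for x in xs])
    restored = []
    for i, y in enumerate(model(x) for x in xs):
        # Each op is its own inverse; undo the last-applied op ('t') first.
        if i > 3:
            y = _transpose(y)
        if i % 4 > 1:
            y = _flip_h(y)
        if (i % 4) % 2 == 1:
            y = _flip_w(y)
        restored.append(y)
    return restored

# For a model that commutes with flips and transposes (the identity or any
# pointwise map), all eight restored outputs are identical, which is why
# averaging them in forward_x8 is a valid self-ensemble.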


================================================
FILE: src/model/common.py
================================================
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size//2), bias=bias)

class MeanShift(nn.Conv2d):
    def __init__(
        self, rgb_range,
        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):

        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        for p in self.parameters():
            p.requires_grad = False
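
# Example sketch (not part of the original repository): what the frozen 1x1
# convolution in MeanShift computes for one RGB pixel. With weight = I / std and
# bias = sign * rgb_range * mean / std, each channel becomes
#     y_c = (x_c + sign * rgb_range * mean_c) / std_c
# so sign=-1 subtracts the mean and sign=+1 adds it back. With the default
# std of 1.0 the two calls are exact inverses; as written, a non-unit std would
# be divided out twice rather than undone. mean_shift_pixel is a hypothetical
# helper name.

DEFAULT_RGB_MEAN = (0.4488, 0.4371, 0.4040)  # default rgb_mean of MeanShift


def mean_shift_pixel(pixel, rgb_range, sign, rgb_mean=DEFAULT_RGB_MEAN,
                     rgb_std=(1.0, 1.0, 1.0)):
    """Apply the MeanShift affine map channel by channel to a 3-tuple."""
    return tuple(
        (x + sign * rgb_range * m) / s
        for x, m, s in zip(pixel, rgb_mean, rgb_std)
    )

# Usage: mean_shift_pixel(mean_shift_pixel(p, 255, -1), 255, +1) returns p
# up to floating-point rounding.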

class BasicBlock(nn.Sequential):
    def __init__(
        self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
        bn=True, act=nn.ReLU(True)):

        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)

        super(BasicBlock, self).__init__(*m)

class ResBlock(nn.Module):
    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res

class Upsampler(nn.Sequential):
    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):

        m = []
        if (scale & (scale - 1)) == 0:    # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feats, 4 * n_feats, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn:
                    m.append(nn.BatchNorm2d(n_feats))
                if act == 'relu':
                    m.append(nn.ReLU(True))
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))

        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn:
                m.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                m.append(nn.ReLU(True))
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)
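
# Example sketch (not part of the original repository): the layer plan that
# Upsampler builds and the shape change of nn.PixelShuffle(r), which maps
# (N, C * r * r, H, W) to (N, C, H * r, W * r). Power-of-two scales use
# log2(scale) stages of conv(n_feats -> 4 * n_feats) + PixelShuffle(2);
# scale 3 uses one conv(n_feats -> 9 * n_feats) + PixelShuffle(3); any other
# scale raises NotImplementedError. upsampler_plan, pixel_shuffle_shape and
# trace_upsampler are hypothetical helper names.


def upsampler_plan(scale, n_feats):
    """List of (conv_in, conv_out, shuffle_factor) stages for a given scale."""
    if scale > 0 and (scale & (scale - 1)) == 0:
        stages = scale.bit_length() - 1  # log2(scale) without a float log
        return [(n_feats, 4 * n_feats, 2)] * stages
    if scale == 3:
        return [(n_feats, 9 * n_feats, 3)]
    raise NotImplementedError('scale {} is not supported'.format(scale))


def pixel_shuffle_shape(shape, r):
    """Output shape of PixelShuffle(r) for an NCHW input shape."""
    n, c, h, w = shape
    if c % (r * r):
        raise ValueError('channels must be divisible by r ** 2')
    return (n, c // (r * r), h * r, w * r)


def trace_upsampler(shape, scale, n_feats):
    """Follow an NCHW feature-map shape through the planned stages."""
    for c_in, c_out, r in upsampler_plan(scale, n_feats):
        assert shape[1] == c_in
        # default_conv pads kernel_size // 2, so the 3x3 conv keeps H and W.
        shape = pixel_shuffle_shape((shape[0], c_out, shape[2], shape[3]), r)
    return shape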



================================================
FILE: src/model/dcn/__init__.py
================================================
from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack,
                          deform_conv, modulated_deform_conv)

__all__ = [
    'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
    'modulated_deform_conv'
]


================================================
FILE: src/model/dcn/deform_conv.py
================================================
import math
import logging

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.utils import _pair

from . import deform_conv_cuda

logger = logging.getLogger('base')


class DeformConvFunction(Function):
    @staticmethod
    def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1,
                deformable_groups=1, im2col_step=64):
        if input is not None and input.dim() != 4:
            raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(
                input.dim()))
        ctx.stride = _pair(stride)
        ctx.padding = _pair(padding)
        ctx.dilation = _pair(dilation)
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.im2col_step = im2col_step

        ctx.save_for_backward(input, offset, weight)

        output = input.new_empty(
            DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))

        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones

        if not input.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
            deform_conv_cuda.deform_conv_forward_cuda(input, weight, offset, output,
                                                      ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
                                                      weight.size(2), ctx.stride[1], ctx.stride[0],
                                                      ctx.padding[1], ctx.padding[0],
                                                      ctx.dilation[1], ctx.dilation[0], ctx.groups,
                                                      ctx.deformable_groups, cur_im2col_step)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        input, offset, weight = ctx.saved_tensors

        grad_input = grad_offset = grad_weight = None

        if not grad_output.is_cuda:
            raise NotImplementedError
        else:
            cur_im2col_step = min(ctx.im2col_step, input.shape[0])
            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                grad_input = torch.zeros_like(input)
                grad_offset = torch.zeros_like(offset)
                deform_conv_cuda.deform_conv_backward_input_cuda(
                    input, offset, grad_output, grad_input, grad_offset, weight, ctx.bufs_[0],
                    weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
                    ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
                    ctx.deformable_groups, cur_im2col_step)

            if ctx.needs_input_grad[2]:
                grad_weight = torch.zeros_like(weight)
                deform_conv_cuda.deform_conv_backward_parameters_cuda(
                    input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1],
                    weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
                    ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
                    ctx.deformable_groups, 1, cur_im2col_step)

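        # One gradient per forward input (9, including im2col_step); trailing
        # Nones beyond the number of arguments actually passed are allowed.
        return (grad_input, grad_offset, grad_weight, None, None, None, None, None, None)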

    @staticmethod
    def _output_size(input, weight, padding, dilation, stride):
        channels = weight.size(0)
        output_size = (input.size(0), channels)
        for d in range(input.dim() - 2):
            in_size = input.size(d + 2)
            pad = padding[d]
            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
            stride_ = stride[d]
            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
        if not all(map(lambda s: s > 0, output_size)):
            raise ValueError("convolution input is too small (output would be {})".format('x'.join(
                map(str, output_size))))
        return output_size
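

# Example sketch (not part of the original repository): the spatial-size rule
# used by DeformConvFunction._output_size (and by shape_check in the CUDA
# source), plus the batch-splitting constraint asserted before the im2col
# kernels run. conv_out_size and im2col_chunks are hypothetical helper names.


def conv_out_size(in_size, kernel, pad=0, stride=1, dilation=1):
    """out = (in + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1"""
    effective_kernel = dilation * (kernel - 1) + 1
    out = (in_size + 2 * pad - effective_kernel) // stride + 1
    if out <= 0:
        raise ValueError(
            'convolution input is too small (output would be {})'.format(out))
    return out


def im2col_chunks(batch, im2col_step=64):
    """(number of kernel calls, images per call) for a batch of given size."""
    # forward/backward use min(im2col_step, batch) images per call and require
    # that step to divide the batch evenly.
    step = min(im2col_step, batch)
    if batch % step != 0:
        raise ValueError('im2col step must divide batchsize')
    return batch // step, step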


class ModulatedDeformConvFunction(Function):
    @staticmethod
    def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1,
                groups=1, deformable_groups=1):
        ctx.stride = stride
        ctx.padding = padding
        ctx.dilation = dilation
        ctx.groups = groups
        ctx.deformable_groups = deformable_groups
        ctx.with_bias = bias is not None
        if not ctx.with_bias:
            bias = input.new_empty(1)  # fake tensor
        if not input.is_cuda:
            raise NotImplementedError
        if weight.requires_grad or mask.requires_grad or offset.requires_grad \
                or input.requires_grad:
            ctx.save_for_backward(input, offset, mask, weight, bias)
        output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
        ctx._bufs = [input.new_empty(0), input.new_empty(0)]
        deform_conv_cuda.modulated_deform_conv_cuda_forward(
            input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1], weight.shape[2],
            weight.shape[3], ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation,
            ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        if not grad_output.is_cuda:
            raise NotImplementedError
        input, offset, mask, weight, bias = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        grad_offset = torch.zeros_like(offset)
        grad_mask = torch.zeros_like(mask)
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        deform_conv_cuda.modulated_deform_conv_cuda_backward(
            input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], grad_input, grad_weight,
            grad_bias, grad_offset, grad_mask, grad_output, weight.shape[2], weight.shape[3],
            ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
            ctx.groups, ctx.deformable_groups, ctx.with_bias)
        if not ctx.with_bias:
            grad_bias = None

        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None,
                None)

    @staticmethod
    def _infer_shape(ctx, input, weight):
        n = input.size(0)
        channels_out = weight.size(0)
        height, width = input.shape[2:4]
        kernel_h, kernel_w = weight.shape[2:4]
        height_out = (height + 2 * ctx.padding - (ctx.dilation *
                                                  (kernel_h - 1) + 1)) // ctx.stride + 1
        width_out = (width + 2 * ctx.padding - (ctx.dilation *
                                                (kernel_w - 1) + 1)) // ctx.stride + 1
        return n, channels_out, height_out, width_out


deform_conv = DeformConvFunction.apply
modulated_deform_conv = ModulatedDeformConvFunction.apply


class DeformConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, deformable_groups=1, bias=False):
        super(DeformConv, self).__init__()

        assert not bias
        assert in_channels % groups == 0, \
            'in_channels {} is not divisible by groups {}'.format(
                in_channels, groups)
        assert out_channels % groups == 0, \
            'out_channels {} is not divisible by groups {}'.format(
                out_channels, groups)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.groups = groups
        self.deformable_groups = deformable_groups

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))

        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x, offset):
        return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
                           self.groups, self.deformable_groups)


class DeformConvPack(DeformConv):
    def __init__(self, *args, **kwargs):
        super(DeformConvPack, self).__init__(*args, **kwargs)

        self.conv_offset = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
            bias=True)
        self.init_offset()

    def init_offset(self):
        self.conv_offset.weight.data.zero_()
        self.conv_offset.bias.data.zero_()

    def forward(self, x):
        offset = self.conv_offset(x)
        return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
                           self.groups, self.deformable_groups)


class ModulatedDeformConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, deformable_groups=1, bias=True):
        super(ModulatedDeformConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.deformable_groups = deformable_groups
        self.with_bias = bias

        self.weight = nn.Parameter(
            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

    def forward(self, x, offset, mask):
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
                                     self.padding, self.dilation, self.groups,
                                     self.deformable_groups)


class ModulatedDeformConvPack(ModulatedDeformConv):
    def __init__(self, *args, extra_offset_mask=False, **kwargs):
        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)

        self.extra_offset_mask = extra_offset_mask
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
            bias=True)
        self.init_offset()

    def init_offset(self):
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, x):
        if self.extra_offset_mask:
            # x = [input, features]
            out = self.conv_offset_mask(x[1])
            x = x[0]
        else:
            out = self.conv_offset_mask(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)

        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 100:
            logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean))

        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
                                     self.padding, self.dilation, self.groups,
                                     self.deformable_groups)
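

# Example sketch (not part of the original repository): how
# ModulatedDeformConvPack interprets the output of conv_offset_mask. That conv
# yields deformable_groups * 3 * kh * kw channels; torch.chunk(out, 3, dim=1)
# splits them into three equal blocks, the first two are concatenated into the
# offset channels and the third goes through a sigmoid to become a mask in
# (0, 1). Because init_offset zeroes the conv, the layer starts with zero
# offsets and a mask of 0.5 everywhere. A flat list stands in for the channel
# dimension here; split_offset_mask is a hypothetical helper name.
import math


def _sigmoid(v):
    # Numerically stable logistic function.
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def split_offset_mask(channels, kernel_size=(3, 3), deformable_groups=1):
    """Split one pixel's conv_offset_mask output into (offset, mask) lists."""
    kh, kw = kernel_size
    expected = deformable_groups * 3 * kh * kw
    if len(channels) != expected:
        raise ValueError('expected {} channels, got {}'.format(expected, len(channels)))
    third = expected // 3
    o1, o2, m = channels[:third], channels[third:2 * third], channels[2 * third:]
    offset = list(o1) + list(o2)        # torch.cat((o1, o2), dim=1)
    mask = [_sigmoid(v) for v in m]     # torch.sigmoid(mask)
    return offset, mask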


================================================
FILE: src/model/dcn/setup.py
================================================
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


def make_cuda_ext(name, sources):

    return CUDAExtension(
        name=name, sources=list(sources), extra_compile_args={
            'cxx': [],
            'nvcc': [
                '-D__CUDA_NO_HALF_OPERATORS__',
                '-D__CUDA_NO_HALF_CONVERSIONS__',
                '-D__CUDA_NO_HALF2_OPERATORS__',
            ]
        })


setup(
    name='deform_conv', ext_modules=[
        make_cuda_ext(name='deform_conv_cuda',
                      sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu'])
    ], cmdclass={'build_ext': BuildExtension}, zip_safe=False)


================================================
FILE: src/model/dcn/src/deform_conv_cuda.cpp
================================================
// modify from
// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c

#include <torch/extension.h>

#include <cmath>
#include <vector>

void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
                       const int channels, const int height, const int width,
                       const int ksize_h, const int ksize_w, const int pad_h,
                       const int pad_w, const int stride_h, const int stride_w,
                       const int dilation_h, const int dilation_w,
                       const int parallel_imgs, const int deformable_group,
                       at::Tensor data_col);

void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
                       const int channels, const int height, const int width,
                       const int ksize_h, const int ksize_w, const int pad_h,
                       const int pad_w, const int stride_h, const int stride_w,
                       const int dilation_h, const int dilation_w,
                       const int parallel_imgs, const int deformable_group,
                       at::Tensor grad_im);

void deformable_col2im_coord(
    const at::Tensor data_col, const at::Tensor data_im,
    const at::Tensor data_offset, const int channels, const int height,
    const int width, const int ksize_h, const int ksize_w, const int pad_h,
    const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int parallel_imgs,
    const int deformable_group, at::Tensor grad_offset);

void modulated_deformable_im2col_cuda(
    const at::Tensor data_im, const at::Tensor data_offset,
    const at::Tensor data_mask, const int batch_size, const int channels,
    const int height_im, const int width_im, const int height_col,
    const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int deformable_group,
    at::Tensor data_col);

void modulated_deformable_col2im_cuda(
    const at::Tensor data_col, const at::Tensor data_offset,
    const at::Tensor data_mask, const int batch_size, const int channels,
    const int height_im, const int width_im, const int height_col,
    const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int deformable_group,
    at::Tensor grad_im);

void modulated_deformable_col2im_coord_cuda(
    const at::Tensor data_col, const at::Tensor data_im,
    const at::Tensor data_offset, const at::Tensor data_mask,
    const int batch_size, const int channels, const int height_im,
    const int width_im, const int height_col, const int width_col,
    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w, const int dilation_h,
    const int dilation_w, const int deformable_group, at::Tensor grad_offset,
    at::Tensor grad_mask);

void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
                 at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
                 int padW, int dilationH, int dilationW, int group,
                 int deformable_group) {
  AT_CHECK(weight.ndimension() == 4,
           "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
           "but got: %s",
           weight.ndimension());

  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

  AT_CHECK(kW > 0 && kH > 0,
           "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
           kW);

  AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
           "kernel size should be consistent with weight, ",
           "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
           kW, weight.size(2), weight.size(3));

  AT_CHECK(dW > 0 && dH > 0,
           "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);

  AT_CHECK(
      dilationW > 0 && dilationH > 0,
      "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
      dilationH, dilationW);

  int ndim = input.ndimension();
  int dimf = 0;
  int dimh = 1;
  int dimw = 2;

  if (ndim == 4) {
    dimf++;
    dimh++;
    dimw++;
  }

  AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
           ndim);

  long nInputPlane = weight.size(1) * group;
  long inputHeight = input.size(dimh);
  long inputWidth = input.size(dimw);
  long nOutputPlane = weight.size(0);
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;

  AT_CHECK(nInputPlane % deformable_group == 0,
           "input channels must be divisible by deformable group size");

  if (outputWidth < 1 || outputHeight < 1)
    AT_ERROR(
        "Given input size: (%ld x %ld x %ld). "
        "Calculated output size: (%ld x %ld x %ld). Output size is too small",
        nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
        outputWidth);

  AT_CHECK(input.size(1) == nInputPlane,
           "invalid number of input planes, expected: %d, but got: %d",
           nInputPlane, input.size(1));

  AT_CHECK((inputHeight >= kH && inputWidth >= kW),
           "input image is smaller than kernel");

  AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
           "invalid spatial size of offset, expected height: %d width: %d, but "
           "got height: %d width: %d",
           outputHeight, outputWidth, offset.size(2), offset.size(3));

  AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
           "invalid number of channels of offset");

  if (gradOutput != NULL) {
    AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
             "invalid number of gradOutput planes, expected: %d, but got: %d",
             nOutputPlane, gradOutput->size(dimf));

    AT_CHECK((gradOutput->size(dimh) == outputHeight &&
              gradOutput->size(dimw) == outputWidth),
             "invalid size of gradOutput, expected height: %d width: %d , but "
             "got height: %d width: %d",
             outputHeight, outputWidth, gradOutput->size(dimh),
             gradOutput->size(dimw));
  }
}

int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
                             at::Tensor offset, at::Tensor output,
                             at::Tensor columns, at::Tensor ones, int kW,
                             int kH, int dW, int dH, int padW, int padH,
                             int dilationW, int dilationH, int group,
                             int deformable_group, int im2col_step) {
  // todo: resize columns to include im2col: done
  // todo: add im2col_step as input
  // todo: add new output buffer and transpose it to output (or directly
  // transpose output) todo: possibly change data indexing because of
  // parallel_imgs

  shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
              dilationH, dilationW, group, deformable_group);

  input = input.contiguous();
  offset = offset.contiguous();
  weight = weight.contiguous();

  int batch = 1;
  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input.unsqueeze_(0);
    offset.unsqueeze_(0);
  }

  // todo: assert batchsize divisible by im2col_step

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = weight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");

  output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
                        outputHeight, outputWidth});
  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
    ones = at::ones({outputHeight, outputWidth}, input.options());
  }

  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  at::Tensor output_buffer =
      at::zeros({batchSize / im2col_step, nOutputPlane,
                 im2col_step * outputHeight, outputWidth},
                output.options());

  output_buffer = output_buffer.view(
      {output_buffer.size(0), group, output_buffer.size(1) / group,
       output_buffer.size(2), output_buffer.size(3)});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, columns);

    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});

    for (int g = 0; g < group; g++) {
      output_buffer[elt][g] = output_buffer[elt][g]
                                  .flatten(1)
                                  .addmm_(weight[g].flatten(1), columns[g])
                                  .view_as(output_buffer[elt][g]);
    }
  }

  output_buffer = output_buffer.view(
      {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
       output_buffer.size(3), output_buffer.size(4)});

  output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
                                      im2col_step, outputHeight, outputWidth});
  output_buffer.transpose_(1, 2);
  output.copy_(output_buffer);
  output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    output = output.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
  }

  return 1;
}

int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
                                    at::Tensor gradOutput, at::Tensor gradInput,
                                    at::Tensor gradOffset, at::Tensor weight,
                                    at::Tensor columns, int kW, int kH, int dW,
                                    int dH, int padW, int padH, int dilationW,
                                    int dilationH, int group,
                                    int deformable_group, int im2col_step) {
  shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
              dilationH, dilationW, group, deformable_group);

  input = input.contiguous();
  offset = offset.contiguous();
  gradOutput = gradOutput.contiguous();
  weight = weight.contiguous();

  int batch = 1;

  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input = input.view({1, input.size(0), input.size(1), input.size(2)});
    offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
    gradOutput = gradOutput.view(
        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
  }

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = weight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  // change order of grad output
  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
                                nOutputPlane, outputHeight, outputWidth});
  gradOutput.transpose_(1, 2);

  gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
                              inputHeight, inputWidth});
  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
                                deformable_group * 2 * kH * kW, outputHeight,
                                outputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    // divide into groups
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});
    gradOutput = gradOutput.view(
        {gradOutput.size(0), group, gradOutput.size(1) / group,
         gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});

    for (int g = 0; g < group; g++) {
      columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
                                     gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
    }
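    // Shape sketch of the per-group GEMM above (G = group, Cin = nInputPlane,
    // Cout = nOutputPlane, N = im2col_step * outputHeight * outputWidth; these
    // short names are introduced here for illustration only):
    //   weight[g].flatten(1).transpose(0, 1) : [Cin / G * kH * kW, Cout / G]
    //   gradOutput[elt][g].flatten(1)        : [Cout / G, N]
    //   columns[g]                           : [Cin / G * kH * kW, N]
    // addmm_(m1, m2, beta = 0, alpha = 1) overwrites columns[g] with m1 * m2,
    // the gradient w.r.t. the im2col buffer, which the two col2im calls below
    // scatter back into gradOffset and gradInput.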

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    gradOutput = gradOutput.view(
        {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
         gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});

    deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
                            inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
                            dilationH, dilationW, im2col_step, deformable_group,
                            gradOffset[elt]);

    deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, gradInput[elt]);
  }

  gradOutput.transpose_(1, 2);
  gradOutput =
      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  gradOffset = gradOffset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
    gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
    gradOffset = gradOffset.view(
        {gradOffset.size(1), gradOffset.size(2), gradOffset.size(3)});
  }

  return 1;
}

int deform_conv_backward_parameters_cuda(
    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
    at::Tensor gradWeight,  // at::Tensor gradBias,
    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
    int padW, int padH, int dilationW, int dilationH, int group,
    int deformable_group, float scale, int im2col_step) {
  // todo: transpose and reshape outGrad
  // todo: reshape columns
  // todo: add im2col_step as input

  shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
              padW, dilationH, dilationW, group, deformable_group);

  input = input.contiguous();
  offset = offset.contiguous();
  gradOutput = gradOutput.contiguous();

  int batch = 1;

  if (input.ndimension() == 3) {
    // Force batch
    batch = 0;
    input = input.view({1, input.size(0), input.size(1), input.size(2)});
    gradOutput = gradOutput.view(
        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
  }

  long batchSize = input.size(0);
  long nInputPlane = input.size(1);
  long inputHeight = input.size(2);
  long inputWidth = input.size(3);

  long nOutputPlane = gradWeight.size(0);

  long outputWidth =
      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
  long outputHeight =
      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

  AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");

  columns = at::zeros(
      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
      input.options());

  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
                                nOutputPlane, outputHeight, outputWidth});
  gradOutput.transpose_(1, 2);

  at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
  gradOutputBuffer =
      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
                             outputHeight, outputWidth});
  gradOutputBuffer.copy_(gradOutput);
  gradOutputBuffer =
      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
                             im2col_step * outputHeight, outputWidth});

  gradOutput.transpose_(1, 2);
  gradOutput =
      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});

  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
                      inputHeight, inputWidth});
  offset =
      offset.view({batchSize / im2col_step, im2col_step,
                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
                      dilationW, im2col_step, deformable_group, columns);

    // divide into groups
    gradOutputBuffer = gradOutputBuffer.view(
        {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
         gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    gradWeight =
        gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
                         gradWeight.size(2), gradWeight.size(3)});

    for (int g = 0; g < group; g++) {
      gradWeight[g] = gradWeight[g]
                          .flatten(1)
                          .addmm_(gradOutputBuffer[elt][g].flatten(1),
                                  columns[g].transpose(1, 0), 1.0, scale)
                          .view_as(gradWeight[g]);
    }
    gradOutputBuffer = gradOutputBuffer.view(
        {gradOutputBuffer.size(0),
         gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
         gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
                                  gradWeight.size(2), gradWeight.size(3),
                                  gradWeight.size(4)});
  }

  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
  offset = offset.view(
      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});

  if (batch == 0) {
    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
    input = input.view({nInputPlane, inputHeight, inputWidth});
  }

  return 1;
}

void modulated_deform_conv_cuda_forward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
    int kernel_h, int kernel_w, const int stride_h, const int stride_w,
    const int pad_h, const int pad_w, const int dilation_h,
    const int dilation_w, const int group, const int deformable_group,
    const bool with_bias) {
  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_out = weight.size(0);
  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);

  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (", kernel_h, " x ",
             kernel_w, " vs ", kernel_h_, " x ", kernel_w_, ").");
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (", channels,
             " vs ", channels_kernel * group, ").");

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  // resize output
  output = output.view({batch, channels_out, height_out, width_out}).zero_();
  // resize temporary columns
  columns =
      at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
                input.options());

  output = output.view({output.size(0), group, output.size(1) / group,
                        output.size(2), output.size(3)});

  for (int b = 0; b < batch; b++) {
    modulated_deformable_im2col_cuda(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    // divide into groups
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});

    for (int g = 0; g < group; g++) {
      output[b][g] = output[b][g]
                         .flatten(1)
                         .addmm_(weight[g].flatten(1), columns[g])
                         .view_as(output[b][g]);
    }

    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});
    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
  }

  output = output.view({output.size(0), output.size(1) * output.size(2),
                        output.size(3), output.size(4)});

  if (with_bias) {
    output += bias.view({1, bias.size(0), 1, 1});
  }
}

void modulated_deform_conv_cuda_backward(
    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
    at::Tensor offset, at::Tensor mask, at::Tensor columns,
    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
    const bool with_bias) {
  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

  const int batch = input.size(0);
  const int channels = input.size(1);
  const int height = input.size(2);
  const int width = input.size(3);

  const int channels_kernel = weight.size(1);
  const int kernel_h_ = weight.size(2);
  const int kernel_w_ = weight.size(3);
  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
    AT_ERROR("Input shape and kernel shape won't match: (", kernel_h, " x ",
             kernel_w, " vs ", kernel_h_, " x ", kernel_w_, ").");
  if (channels != channels_kernel * group)
    AT_ERROR("Input shape and kernel channels won't match: (", channels,
             " vs ", channels_kernel * group, ").");

  const int height_out =
      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
  const int width_out =
      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

  if (ones.ndimension() != 2 ||
      ones.size(0) * ones.size(1) < height_out * width_out) {
    // Resize plane and fill with ones...
    ones = at::ones({height_out, width_out}, input.options());
  }

  grad_input = grad_input.view({batch, channels, height, width});
  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
                      input.options());

  grad_output =
      grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
                        grad_output.size(2), grad_output.size(3)});

  for (int b = 0; b < batch; b++) {
    // divide into groups
    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    weight = weight.view({group, weight.size(0) / group, weight.size(1),
                          weight.size(2), weight.size(3)});

    for (int g = 0; g < group; g++) {
      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
                        grad_output[b][g].flatten(1), 0.0f, 1.0f);
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
                          weight.size(3), weight.size(4)});

    // gradient w.r.t. input coordinate data
    modulated_deformable_col2im_coord_cuda(
        columns, input[b], offset[b], mask[b], 1, channels, height, width,
        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
        stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
        grad_mask[b]);
    // gradient w.r.t. input data
    modulated_deformable_col2im_cuda(
        columns, offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, grad_input[b]);

    // gradient w.r.t. weight; dWeight accumulates across the batch and groups
    modulated_deformable_im2col_cuda(
        input[b], offset[b], mask[b], 1, channels, height, width, height_out,
        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
        dilation_h, dilation_w, deformable_group, columns);

    columns = columns.view({group, columns.size(0) / group, columns.size(1)});
    grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
                                    grad_weight.size(1), grad_weight.size(2),
                                    grad_weight.size(3)});
    if (with_bias)
      grad_bias = grad_bias.view({group, grad_bias.size(0) / group});

    for (int g = 0; g < group; g++) {
      grad_weight[g] =
          grad_weight[g]
              .flatten(1)
              .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
              .view_as(grad_weight[g]);
      if (with_bias) {
        grad_bias[g] =
            grad_bias[g]
                .view({-1, 1})
                .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
                .view(-1);
      }
    }

    columns =
        columns.view({columns.size(0) * columns.size(1), columns.size(2)});
    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
                                    grad_weight.size(2), grad_weight.size(3),
                                    grad_weight.size(4)});
    if (with_bias)
      grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
  }
  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
                                  grad_output.size(2), grad_output.size(3),
                                  grad_output.size(4)});
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda,
        "deform forward (CUDA)");
  m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda,
        "deform_conv_backward_input (CUDA)");
  m.def("deform_conv_backward_parameters_cuda",
        &deform_conv_backward_parameters_cuda,
        "deform_conv_backward_parameters (CUDA)");
  m.def("modulated_deform_conv_cuda_forward",
        &modulated_deform_conv_cuda_forward,
        "modulated deform conv forward (CUDA)");
  m.def("modulated_deform_conv_cuda_backward",
        &modulated_deform_conv_cuda_backward,
        "modulated deform conv backward (CUDA)");
}


================================================
FILE: src/model/dcn/src/deform_conv_cuda_kernel.cu
================================================
/*!
 ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
 *
 * COPYRIGHT
 *
 * All contributions by the University of California:
 * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
 * All rights reserved.
 *
 * All other contributions:
 * Copyright (c) 2014-2017, the respective contributors
 * All rights reserved.
 *
 * Caffe uses a shared copyright model: each contributor holds copyright over
 * their contributions to Caffe. The project versioning records all such
 * contribution and copyright details. If a contributor wants to further mark
 * their specific copyright on a particular contribution, they should indicate
 * their copyright solely in the commit message of the change when it is
 * committed.
 *
 * LICENSE
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * CONTRIBUTION AGREEMENT
 *
 * By contributing to the BVLC/caffe repository through pull-request, comment,
 * or otherwise, the contributor releases their content to the
 * license and copyright terms herein.
 *
 ***************** END Caffe Copyright Notice and Disclaimer ********************
 *
 * Copyright (c) 2018 Microsoft
 * Licensed under The MIT License [see LICENSE for details]
 * \file modulated_deformable_im2col.cuh
 * \brief Function definitions of converting an image to
 * column matrix based on kernel, padding, dilation, and offset.
 * These functions are mainly used in deformable convolution operators.
 * \ref: https://arxiv.org/abs/1703.06211
 * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
 */

// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>
#include <stdio.h>
#include <math.h>
#include <float.h>

using namespace at;

#define CUDA_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

const int CUDA_NUM_THREADS = 1024;
const int kMaxGridNum = 65535;

inline int GET_BLOCKS(const int N)
{
  return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
}
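
// Worked example of GET_BLOCKS (illustrative values): N = 5000 elements gives
// (5000 + 1023) / 1024 = 5 blocks of 1024 threads (5 * 1024 = 5120 >= 5000).
// The grid is capped at kMaxGridNum = 65535 blocks; elements beyond
// 65535 * 1024 = 67,107,840 are still processed because CUDA_KERNEL_LOOP is a
// grid-stride loop (each thread advances by blockDim.x * gridDim.x).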

template <typename scalar_t>
__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
                                               const int height, const int width, scalar_t h, scalar_t w)
{

  int h_low = floor(h);
  int w_low = floor(w);
  int h_high = h_low + 1;
  int w_high = w_low + 1;

  scalar_t lh = h - h_low;
  scalar_t lw = w - w_low;
  scalar_t hh = 1 - lh, hw = 1 - lw;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
    v1 = bottom_data[h_low * data_width + w_low];
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
    v2 = bottom_data[h_low * data_width + w_high];
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
    v3 = bottom_data[h_high * data_width + w_low];
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
    v4 = bottom_data[h_high * data_width + w_high];

  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}
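
// Worked example of deformable_im2col_bilinear (illustrative values; I[r][c]
// is shorthand for bottom_data[r * data_width + c], not a name in this file):
// sampling at (h, w) = (1.25, 2.5) inside the image gives h_low = 1,
// w_low = 2, lh = 0.25, lw = 0.5 and corner weights w1 = 0.375, w2 = 0.375,
// w3 = 0.125, w4 = 0.125 (summing to 1), so
// val = 0.375 * I[1][2] + 0.375 * I[1][3] + 0.125 * I[2][2] + 0.125 * I[2][3].
// Corners that fall outside the image contribute 0 (implicit zero padding).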

template <typename scalar_t>
__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
                                        const int h, const int w, const int height, const int width)
{

  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
  {
    //empty
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  scalar_t weight = 0;
  if (h == argmax_h_low && w == argmax_w_low)
    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
  if (h == argmax_h_low && w == argmax_w_high)
    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
  if (h == argmax_h_high && w == argmax_w_low)
    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
  if (h == argmax_h_high && w == argmax_w_high)
    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
  return weight;
}
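
// get_gradient_weight returns the bilinear weight that pixel (h, w) received
// when the sample at (argmax_h, argmax_w) was taken in the forward pass, so
// col2im can route the column gradient back to the same four pixels. With the
// illustrative sample (1.25, 2.5) and pixel (2, 3) = (h_high, w_high) it gives
// (1.25 + 1 - 2) * (2.5 + 1 - 3) = 0.25 * 0.5 = 0.125, matching w4 in
// deformable_im2col_bilinear.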

template <typename scalar_t>
__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
                                          const int height, const int width, const scalar_t *im_data,
                                          const int data_width, const int bp_dir)
{

  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
  {
    //empty
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  scalar_t weight = 0;

  if (bp_dir == 0)
  {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }
  else if (bp_dir == 1)
  {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }

  return weight;
}
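
// get_coordinate_weight is the partial derivative of the bilinear sample with
// respect to its sampling coordinate (before scaling by the upstream
// gradient): bp_dir == 0 gives d(val)/dh and bp_dir == 1 gives d(val)/dw.
// Using the notation of deformable_im2col_bilinear,
//   val = (1-lh)(1-lw) v1 + (1-lh) lw v2 + lh (1-lw) v3 + lh lw v4,
// so
//   d(val)/dh = -(1-lw) v1 - lw v2 + (1-lw) v3 + lw v4
//   d(val)/dw = -(1-lh) v1 + (1-lh) v2 - lh v3 + lh v4,
// which are the four terms accumulated in the two branches above
// (out-of-image corners are skipped, i.e. treated as 0).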

template <typename scalar_t>
__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
                                             const int height, const int width, const int kernel_h, const int kernel_w,
                                             const int pad_h, const int pad_w, const int stride_h, const int stride_w,
                                             const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
                                             const int batch_size, const int num_channels, const int deformable_group,
                                             const int height_col, const int width_col,
                                             scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // index of the output column matrix
    const int w_col = index % width_col;
    const int h_col = (index / width_col) % height_col;
    const int b_col = (index / width_col / height_col) % batch_size;
    const int c_im = (index / width_col / height_col) / batch_size;
    const int c_col = c_im * kernel_h * kernel_w;
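    // The flat index is decoded with w_col varying fastest, i.e. as
    // [c_im][b_col][h_col][w_col]. Each thread then writes kernel_h * kernel_w
    // entries of data_col, whose layout is
    // [c_im * kernel_h * kernel_w + i * kernel_w + j][b_col][h_col][w_col];
    // hence data_col_ptr advances by batch_size * height_col * width_col per
    // kernel tap (i, j) in the loop below.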

    // compute deformable group index
    const int deformable_group_index = c_im / channel_per_deformable_group;

    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;
    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;

    for (int i = 0; i < kernel_h; ++i)
    {
      for (int j = 0; j < kernel_w; ++j)
      {
        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
        scalar_t val = static_cast<scalar_t>(0);
        const scalar_t h_im = h_in + i * dilation_h + offset_h;
        const scalar_t w_im = w_in + j * dilation_w + offset_w;
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
        {
          //const scalar_t map_h = i * dilation_h + offset_h;
          //const scalar_t map_w = j * dilation_w + offset_w;
          //const int cur_height = height - h_in;
          //const int cur_width = width - w_in;
          //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
          val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
        }
        *data_col_ptr = val;
        data_col_ptr += batch_size * height_col * width_col;
      }
    }
  }
}

void deformable_im2col(
    const at::Tensor data_im, const at::Tensor data_offset, const int channels,
    const int height, const int width, const int ksize_h, const int ksize_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w, const int parallel_imgs,
    const int deformable_group, at::Tensor data_col)
{
  // num_axes should be smaller than block size
  // todo: check parallel_imgs is correctly passed in
  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  int num_kernels = channels * height_col * width_col * parallel_imgs;
  int channel_per_deformable_group = channels / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
        const scalar_t *data_im_ = data_im.data<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data<scalar_t>();
        scalar_t *data_col_ = data_col.data<scalar_t>();

        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
            num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
            channel_per_deformable_group, parallel_imgs, channels, deformable_group,
            height_col, width_col, data_col_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
  }
}

template <typename scalar_t>
__global__ void deformable_col2im_gpu_kernel(
    const int n, const scalar_t *data_col, const scalar_t *data_offset,
    const int channels, const int height, const int width,
    const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int channel_per_deformable_group,
    const int batch_size, const int deformable_group,
    const int height_col, const int width_col,
    scalar_t *grad_im)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    const int j = (index / width_col / height_col / batch_size) % kernel_w;
    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
    // decode (j, i, c) above and (w_out, h_out, b) below from the flat index

    const int deformable_group_index = c / channel_per_deformable_group;

    int w_out = index % width_col;
    int h_out = (index / width_col) % height_col;
    int b = (index / width_col / height_col) % batch_size;
    int w_in = w_out * stride_w - pad_w;
    int h_in = h_out * stride_h - pad_h;

    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
                                                        2 * kernel_h * kernel_w * height_col * width_col;
    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;

    const scalar_t cur_top_grad = data_col[index];
    const int cur_h = (int)cur_inv_h_data;
    const int cur_w = (int)cur_inv_w_data;
    for (int dy = -2; dy <= 2; dy++)
    {
      for (int dx = -2; dx <= 2; dx++)
      {
        if (cur_h + dy >= 0 && cur_h + dy < height &&
            cur_w + dx >= 0 && cur_w + dx < width &&
            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
            abs(cur_inv_w_data - (cur_w + dx)) < 1)
        {
          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
          scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
        }
      }
    }
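    // Only pixels closer than 1 to the sampling point in both axes get a
    // nonzero bilinear weight, so at most 4 of the 5 x 5 candidates scanned
    // above pass the test (the floor/ceil neighbours in each axis). Since
    // cur_h and cur_w are truncated toward zero, each is already one of those
    // neighbours, so a +/-1 window would suffice and the +/-2 scan is
    // conservative. atomicAdd is needed because different column entries can
    // sample the same input pixel.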
  }
}

void deformable_col2im(
    const at::Tensor data_col, const at::Tensor data_offset, const int channels,
    const int height, const int width, const int ksize_h,
    const int ksize_w, const int pad_h, const int pad_w,
    const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int parallel_imgs, const int deformable_group,
    at::Tensor grad_im)
{

  // todo: make sure parallel_imgs is passed in correctly
  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
  int channel_per_deformable_group = channels / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();

        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
            num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
            ksize_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channel_per_deformable_group,
            parallel_imgs, deformable_group, height_col, width_col, grad_im_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
  }
}
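The launchers above size the column buffer with the usual convolution output formula, and the kernels index the learned offsets per deformable group as `2 * kernel_h * kernel_w` channels: the vertical offset of kernel tap `(i, j)` sits in channel `2 * (i * kernel_w + j)` and the horizontal offset in the channel after it. The pure-Python sketch below replays that arithmetic. The helper names `conv_out_size`, `offset_channels` and `sample_position` are hypothetical (not part of this repository), and a single stride, padding and dilation is assumed for both axes for brevity.

```python
def conv_out_size(size, ksize, pad, stride, dilation):
    # Same expression as height_col / width_col in deformable_col2im
    effective_k = dilation * (ksize - 1) + 1  # extent of the dilated kernel
    return (size + 2 * pad - effective_k) // stride + 1


def offset_channels(i, j, kernel_w):
    # Channels holding (offset_h, offset_w) for kernel tap (i, j), relative to
    # the start of one deformable group's offset block. The modulation mask
    # used by the modulated kernels lives in channel i * kernel_w + j.
    k = i * kernel_w + j
    return 2 * k, 2 * k + 1


def sample_position(h_out, w_out, i, j, stride, pad, dilation, off_h, off_w):
    # Fractional input position read by tap (i, j) for output pixel
    # (h_out, w_out): h_in + i * dilation_h + offset_h, and likewise for w.
    h_in = h_out * stride - pad
    w_in = w_out * stride - pad
    return h_in + i * dilation + off_h, w_in + j * dilation + off_w


print(conv_out_size(64, 3, 1, 1, 1))   # 64
print(conv_out_size(64, 3, 1, 2, 1))   # 32
print(offset_channels(1, 2, 3))        # (10, 11)
print(sample_position(0, 0, 1, 1, 1, 1, 1, 0.25, -0.5))  # (0.25, -0.5)
```

With zero offsets, `sample_position` reduces to the regular convolution grid, which is why a deformable layer whose offset branch is initialised to zero starts out behaving like an ordinary convolution.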

template <typename scalar_t>
__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
                                                   const scalar_t *data_im, const scalar_t *data_offset,
                                                   const int channels, const int height, const int width,
                                                   const int kernel_h, const int kernel_w,
                                                   const int pad_h, const int pad_w,
                                                   const int stride_h, const int stride_w,
                                                   const int dilation_h, const int dilation_w,
                                                   const int channel_per_deformable_group,
                                                   const int batch_size, const int offset_channels, const int deformable_group,
                                                   const int height_col, const int width_col, scalar_t *grad_offset)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    scalar_t val = 0;
    int w = index % width_col;
    int h = (index / width_col) % height_col;
    int c = (index / width_col / height_col) % offset_channels;
    int b = (index / width_col / height_col) / offset_channels;
    // compute the start and end of the output

    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
    const int col_step = kernel_h * kernel_w;
    int cnt = 0;
    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
                                                  batch_size * width_col * height_col;
    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
                                                channel_per_deformable_group / kernel_h / kernel_w * height * width;
    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
                                                        kernel_h * kernel_w * height_col * width_col;

    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;

    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
    {
      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
      const int bp_dir = offset_c % 2;

      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
      int w_out = col_pos % width_col;
      int h_out = (col_pos / width_col) % height_col;
      int w_in = w_out * stride_w - pad_w;
      int h_in = h_out * stride_h - pad_h;
      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
      scalar_t inv_h = h_in + i * dilation_h + offset_h;
      scalar_t inv_w = w_in + j * dilation_w + offset_w;
      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
      {
        inv_h = inv_w = -2;
      }
      const scalar_t weight = get_coordinate_weight(
          inv_h, inv_w,
          height, width, data_im_ptr + cnt * height * width, width, bp_dir);
      val += weight * data_col_ptr[col_pos];
      cnt += 1;
    }

    grad_offset[index] = val;
  }
}

void deformable_col2im_coord(
    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
    const int channels, const int height, const int width, const int ksize_h,
    const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
    const int stride_w, const int dilation_h, const int dilation_w,
    const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
{

  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
  int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();

        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
            num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
            ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channel_per_deformable_group,
            parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
            height_col, width_col, grad_offset_);
      }));
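
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in deformable_col2im_coord: %s\n", cudaGetErrorString(err));
  }
}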

template <typename scalar_t>
__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
                                         const int height, const int width, scalar_t h, scalar_t w)
{
  int h_low = floor(h);
  int w_low = floor(w);
  int h_high = h_low + 1;
  int w_high = w_low + 1;

  scalar_t lh = h - h_low;
  scalar_t lw = w - w_low;
  scalar_t hh = 1 - lh, hw = 1 - lw;

  scalar_t v1 = 0;
  if (h_low >= 0 && w_low >= 0)
    v1 = bottom_data[h_low * data_width + w_low];
  scalar_t v2 = 0;
  if (h_low >= 0 && w_high <= width - 1)
    v2 = bottom_data[h_low * data_width + w_high];
  scalar_t v3 = 0;
  if (h_high <= height - 1 && w_low >= 0)
    v3 = bottom_data[h_high * data_width + w_low];
  scalar_t v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1)
    v4 = bottom_data[h_high * data_width + w_high];

  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  return val;
}
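`dmcn_im2col_bilinear` is plain bilinear interpolation in which any neighbour that falls outside the image contributes zero, so samples near the border fade toward zero instead of being clamped to the edge pixel. A minimal pure-Python sketch, assuming a single-channel image stored as a list of rows; `bilinear_zero_pad` is a hypothetical name used only for illustration:

```python
import math


def bilinear_zero_pad(img, h, w):
    """Sample img (list of rows) at fractional (h, w) with zero padding."""
    height, width = len(img), len(img[0])
    h_low, w_low = math.floor(h), math.floor(w)
    h_high, w_high = h_low + 1, w_low + 1
    lh, lw = h - h_low, w - w_low  # fractional parts
    hh, hw = 1 - lh, 1 - lw

    def pixel(y, x):
        # Out-of-range neighbours read as 0, like the v1..v4 guards in the kernel
        if 0 <= y < height and 0 <= x < width:
            return img[y][x]
        return 0.0

    return (hh * hw * pixel(h_low, w_low) + hh * lw * pixel(h_low, w_high)
            + lh * hw * pixel(h_high, w_low) + lh * lw * pixel(h_high, w_high))


img = [[1.0, 2.0], [3.0, 4.0]]
print(bilinear_zero_pad(img, 0.5, 0.5))  # 2.5, the average of all four pixels
print(bilinear_zero_pad(img, 1.5, 1.5))  # 1.0, three neighbours lie outside
```

The CUDA version only tests the lower bounds for `v1` because its callers already guarantee `-1 < h < height`; the sketch checks both bounds since Python would otherwise wrap negative indices.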

template <typename scalar_t>
__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
                                             const int h, const int w, const int height, const int width)
{
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
  {
    //empty
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  scalar_t weight = 0;
  if (h == argmax_h_low && w == argmax_w_low)
    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
  if (h == argmax_h_low && w == argmax_w_high)
    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
  if (h == argmax_h_high && w == argmax_w_low)
    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
  if (h == argmax_h_high && w == argmax_w_high)
    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
  return weight;
}

template <typename scalar_t>
__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
                                               const int height, const int width, const scalar_t *im_data,
                                               const int data_width, const int bp_dir)
{
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
  {
    //empty
    return 0;
  }

  int argmax_h_low = floor(argmax_h);
  int argmax_w_low = floor(argmax_w);
  int argmax_h_high = argmax_h_low + 1;
  int argmax_w_high = argmax_w_low + 1;

  scalar_t weight = 0;

  if (bp_dir == 0)
  {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }
  else if (bp_dir == 1)
  {
    if (argmax_h_low >= 0 && argmax_w_low >= 0)
      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
  }

  return weight;
}
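The two weight helpers implement the backward pass of that sampler. `dmcn_get_gradient_weight` returns the bilinear coefficient that input pixel `(h, w)` received, so `col2im` scatters gradients with exactly the weights `im2col` gathered with. `dmcn_get_coordinate_weight` returns the partial derivative of the sampled value with respect to the sampling row (`bp_dir == 0`) or column (`bp_dir == 1`), which is what the offset gradient needs. The pure-Python check below compares both against the forward sampler; all helper names are hypothetical, and the sample point is deliberately interior because the derivative is undefined where a coordinate is an integer.

```python
import math


def pixel(img, y, x):
    # Zero padding outside the image, as in dmcn_im2col_bilinear
    if 0 <= y < len(img) and 0 <= x < len(img[0]):
        return img[y][x]
    return 0.0


def bilinear(img, h, w):
    h0, w0 = math.floor(h), math.floor(w)
    lh, lw = h - h0, w - w0
    return ((1 - lh) * (1 - lw) * pixel(img, h0, w0)
            + (1 - lh) * lw * pixel(img, h0, w0 + 1)
            + lh * (1 - lw) * pixel(img, h0 + 1, w0)
            + lh * lw * pixel(img, h0 + 1, w0 + 1))


def gradient_weight(h, w, y, x):
    # Mirrors dmcn_get_gradient_weight for an in-range sample (h, w)
    h0, w0 = math.floor(h), math.floor(w)
    if y == h0 and x == w0:
        return (y + 1 - h) * (x + 1 - w)
    if y == h0 and x == w0 + 1:
        return (y + 1 - h) * (w + 1 - x)
    if y == h0 + 1 and x == w0:
        return (h + 1 - y) * (x + 1 - w)
    if y == h0 + 1 and x == w0 + 1:
        return (h + 1 - y) * (w + 1 - x)
    return 0.0


def coordinate_weight(img, h, w, bp_dir):
    # Mirrors dmcn_get_coordinate_weight: d(bilinear)/dh if bp_dir == 0,
    # d(bilinear)/dw if bp_dir == 1
    h0, w0 = math.floor(h), math.floor(w)
    lh, lw = h - h0, w - w0
    v1, v2 = pixel(img, h0, w0), pixel(img, h0, w0 + 1)
    v3, v4 = pixel(img, h0 + 1, w0), pixel(img, h0 + 1, w0 + 1)
    if bp_dir == 0:
        return -(1 - lw) * v1 - lw * v2 + (1 - lw) * v3 + lw * v4
    return -(1 - lh) * v1 + (1 - lh) * v2 - lh * v3 + lh * v4


img = [[1.0, 2.0, 0.5], [3.0, 4.0, 1.5], [0.0, 2.5, 6.0]]
h, w, eps = 0.3, 1.6, 1e-6

# Gather vs. scatter: the col2im weights reproduce the im2col sample
scattered = sum(gradient_weight(h, w, y, x) * img[y][x]
                for y in range(3) for x in range(3))
print(abs(scattered - bilinear(img, h, w)) < 1e-12)  # True

# Coordinate weights match central finite differences of the sampler
fd_h = (bilinear(img, h + eps, w) - bilinear(img, h - eps, w)) / (2 * eps)
fd_w = (bilinear(img, h, w + eps) - bilinear(img, h, w - eps)) / (2 * eps)
print(abs(coordinate_weight(img, h, w, 0) - fd_h) < 1e-6)  # True
print(abs(coordinate_weight(img, h, w, 1) - fd_w) < 1e-6)  # True
```

Inside one interpolation cell the sampler is linear in each coordinate, so the central difference agrees with the analytic weight up to floating-point rounding.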

template <typename scalar_t>
__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
                                                       const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
                                                       const int height, const int width, const int kernel_h, const int kernel_w,
                                                       const int pad_h, const int pad_w,
                                                       const int stride_h, const int stride_w,
                                                       const int dilation_h, const int dilation_w,
                                                       const int channel_per_deformable_group,
                                                       const int batch_size, const int num_channels, const int deformable_group,
                                                       const int height_col, const int width_col,
                                                       scalar_t *data_col)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    // each index covers one (c_im, b_col, h_col, w_col) position of the column buffer
    const int w_col = index % width_col;
    const int h_col = (index / width_col) % height_col;
    const int b_col = (index / width_col / height_col) % batch_size;
    const int c_im = (index / width_col / height_col) / batch_size;
    const int c_col = c_im * kernel_h * kernel_w;

    // compute deformable group index
    const int deformable_group_index = c_im / channel_per_deformable_group;

    const int h_in = h_col * stride_h - pad_h;
    const int w_in = w_col * stride_w - pad_w;

    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
    //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;

    const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;

    for (int i = 0; i < kernel_h; ++i)
    {
      for (int j = 0; j < kernel_w; ++j)
      {
        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
        const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
        scalar_t val = static_cast<scalar_t>(0);
        const scalar_t h_im = h_in + i * dilation_h + offset_h;
        const scalar_t w_im = w_in + j * dilation_w + offset_w;
        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
        {
          //const float map_h = i * dilation_h + offset_h;
          //const float map_w = j * dilation_w + offset_w;
          //const int cur_height = height - h_in;
          //const int cur_width = width - w_in;
          //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
          val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
        }
        *data_col_ptr = val * mask;
        data_col_ptr += batch_size * height_col * width_col;
        //data_col_ptr += height_col * width_col;
      }
    }
  }
}

template <typename scalar_t>
__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
                                                       const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
                                                       const int channels, const int height, const int width,
                                                       const int kernel_h, const int kernel_w,
                                                       const int pad_h, const int pad_w,
                                                       const int stride_h, const int stride_w,
                                                       const int dilation_h, const int dilation_w,
                                                       const int channel_per_deformable_group,
                                                       const int batch_size, const int deformable_group,
                                                       const int height_col, const int width_col,
                                                       scalar_t *grad_im)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    const int j = (index / width_col / height_col / batch_size) % kernel_w;
    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
    // compute the start and end of the output

    const int deformable_group_index = c / channel_per_deformable_group;

    int w_out = index % width_col;
    int h_out = (index / width_col) % height_col;
    int b = (index / width_col / height_col) % batch_size;
    int w_in = w_out * stride_w - pad_w;
    int h_in = h_out * stride_h - pad_h;

    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
    const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
    const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;

    const scalar_t cur_top_grad = data_col[index] * mask;
    const int cur_h = (int)cur_inv_h_data;
    const int cur_w = (int)cur_inv_w_data;
    for (int dy = -2; dy <= 2; dy++)
    {
      for (int dx = -2; dx <= 2; dx++)
      {
        if (cur_h + dy >= 0 && cur_h + dy < height &&
            cur_w + dx >= 0 && cur_w + dx < width &&
            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
            abs(cur_inv_w_data - (cur_w + dx)) < 1)
        {
          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
          scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
        }
      }
    }
  }
}

template <typename scalar_t>
__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
                                                             const scalar_t *data_col, const scalar_t *data_im,
                                                             const scalar_t *data_offset, const scalar_t *data_mask,
                                                             const int channels, const int height, const int width,
                                                             const int kernel_h, const int kernel_w,
                                                             const int pad_h, const int pad_w,
                                                             const int stride_h, const int stride_w,
                                                             const int dilation_h, const int dilation_w,
                                                             const int channel_per_deformable_group,
                                                             const int batch_size, const int offset_channels, const int deformable_group,
                                                             const int height_col, const int width_col,
                                                             scalar_t *grad_offset, scalar_t *grad_mask)
{
  CUDA_KERNEL_LOOP(index, n)
  {
    scalar_t val = 0, mval = 0;
    int w = index % width_col;
    int h = (index / width_col) % height_col;
    int c = (index / width_col / height_col) % offset_channels;
    int b = (index / width_col / height_col) / offset_channels;
    // compute the start and end of the output

    const int deformable_group_index = c / (2 * kernel_h * kernel_w);
    const int col_step = kernel_h * kernel_w;
    int cnt = 0;
    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
    const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;

    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;

    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
    {
      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
      const int bp_dir = offset_c % 2;

      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
      int w_out = col_pos % width_col;
      int h_out = (col_pos / width_col) % height_col;
      int w_in = w_out * stride_w - pad_w;
      int h_in = h_out * stride_h - pad_h;
      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
      const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
      const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
      scalar_t inv_h = h_in + i * dilation_h + offset_h;
      scalar_t inv_w = w_in + j * dilation_w + offset_w;
      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
      {
        inv_h = inv_w = -2;
      }
      else
      {
        mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
      }
      const scalar_t weight = dmcn_get_coordinate_weight(
          inv_h, inv_w,
          height, width, data_im_ptr + cnt * height * width, width, bp_dir);
      val += weight * data_col_ptr[col_pos] * mask;
      cnt += 1;
    }
    // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
    grad_offset[index] = val;
    if (offset_c % 2 == 0)
      // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
      grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
  }
}

void modulated_deformable_im2col_cuda(
    const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group, at::Tensor data_col)
{
  // num_axes should be smaller than block size
  const int channel_per_deformable_group = channels / deformable_group;
  const int num_kernels = channels * batch_size * height_col * width_col;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
        scalar_t *data_col_ = data_col.data_ptr<scalar_t>();

        modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
            num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kernel_w,
            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
            batch_size, channels, deformable_group, height_col, width_col, data_col_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
  }
}

void modulated_deformable_col2im_cuda(
    const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group, at::Tensor grad_im)
{

  const int channel_per_deformable_group = channels / deformable_group;
  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
        scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();

        modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
            num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channel_per_deformable_group,
            batch_size, deformable_group, height_col, width_col, grad_im_);
      }));

  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
  }
}

void modulated_deformable_col2im_coord_cuda(
    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
    const int batch_size, const int channels, const int height_im, const int width_im,
    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
    const int dilation_h, const int dilation_w,
    const int deformable_group,
    at::Tensor grad_offset, at::Tensor grad_mask)
{
  const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
        const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
        const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
        const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
        const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
        scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
        scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();

        modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
            num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
            dilation_h, dilation_w, channel_per_deformable_group,
            batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
            grad_offset_, grad_mask_);
      }));
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess)
  {
    printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
  }
}


================================================
FILE: src/model/ddbpn.py
================================================
# Deep Back-Projection Networks For Super-Resolution
# https://arxiv.org/abs/1803.02735

from model import common

import torch
import torch.nn as nn


def make_model(args, parent=False):
    return DDBPN(args)

def projection_conv(in_channels, out_channels, scale, up=True):
    kernel_size, stride, padding = {
        2: (6, 2, 2),
        4: (8, 4, 2),
        8: (12, 8, 2)
    }[scale]
    if up:
        conv_f = nn.ConvTranspose2d
    else:
        conv_f = nn.Conv2d

    return conv_f(
        in_channels, out_channels, kernel_size,
        stride=stride, padding=padding
    )
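The `(kernel_size, stride, padding)` triples in `projection_conv` appear chosen so that the transposed convolution enlarges a feature map by exactly `scale` and the matching strided convolution shrinks it back to the original size. A quick pure-Python check using the output-size formulas from the PyTorch documentation for `nn.ConvTranspose2d` (assuming `dilation=1`, `output_padding=0`) and `nn.Conv2d` (assuming `dilation=1`); `up_size` and `down_size` are hypothetical helper names:

```python
# Settings copied from projection_conv: scale -> (kernel_size, stride, padding)
PROJECTION = {2: (6, 2, 2), 4: (8, 4, 2), 8: (12, 8, 2)}


def up_size(n, k, s, p):
    # nn.ConvTranspose2d output size (dilation=1, output_padding=0)
    return (n - 1) * s - 2 * p + k


def down_size(n, k, s, p):
    # nn.Conv2d output size (dilation=1)
    return (n + 2 * p - k) // s + 1


for scale, (k, s, p) in PROJECTION.items():
    for n in (16, 48):
        hr = up_size(n, k, s, p)
        assert hr == n * scale           # up-projection: exactly x scale
        assert down_size(hr, k, s, p) == n  # down-projection: back to n
print("projection sizes are consistent")
```

Algebraically, every triple satisfies `k - 2 * p == s`, which is what makes `up_size` collapse to `n * s` and `down_size` to `n` for any input size.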

class DenseProjection(nn.Module):
    def __init__(self, in_channels, nr, scale, up=True, bottleneck=True):
        super(DenseProjection, self).__init__()
        if bottleneck:
            self.bottleneck = nn.Sequential(*[
                nn.Conv2d(in_channels, nr, 1),
                nn.PReLU(nr)
            ])
            inter_channels = nr
        else:
            self.bottleneck = None
            inter_channels = in_channels

        self.conv_1 = nn.Sequential(*[
            projection_conv(inter_channels, nr, scale, up),
            nn.PReLU(nr)
        ])
        self.conv_2 = nn.Sequential(*[
            projection_conv(nr, inter_channels, scale, not up),
            nn.PReLU(inter_channels)
        ])
        self.conv_3 = nn.Sequential(*[
            projection_conv(inter_channels, nr, scale, up),
            nn.PReLU(nr)
        ])

    def forward(self, x):
        if self.bottleneck is not None:
            x = self.bottleneck(x)

        a_0 = self.conv_1(x)
        b_0 = self.conv_2(a_0)
        e = b_0.sub(x)
        a_1 = self.conv_3(e)

        out = a_0.add(a_1)

        return out

class DDBPN(nn.Module):
    def __init__(self, args):
        super(DDBPN, self).__init__()
        scale = args.scale[0]

        n0 = 128
        nr = 32
        self.depth = 6

        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
        initial = [
            nn.Conv2d(args.n_colors, n0, 3, padding=1),
            nn.PReLU(n0),
            nn.Conv2d(n0, nr, 1),
            nn.PReLU(nr)
        ]
        self.initial = nn.Sequential(*initial)

        self.upmodules = nn.ModuleList()
        self.downmodules = nn.ModuleList()
        channels = nr
        for i in range(self.depth):
            self.upmodules.append(
                DenseProjection(channels, nr, scale, True, i > 1)
            )
            if i != 0:
                channels += nr
        
        channels = nr
        for i in range(self.depth - 1):
            self.downmodules.append(
                DenseProjection(channels, nr, scale, False, i != 0)
            )
            channels += nr

        reconstruction = [
            nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1) 
        ]
        self.reconstruction = nn.Sequential(*reconstruction)

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.initial(x)

        h_list = []
        l_list = []
        for i in range(self.depth - 1):
            if i == 0:
                l = x
            else:
                l = torch.cat(l_list, dim=1)
            h_list.append(self.upmodules[i](l))
            l_list.append(self.downmodules[i](torch.cat(h_list, dim=1)))
        
        h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1)))
        out = self.reconstruction(torch.cat(h_list, dim=1))
        out = self.add_mean(out)

        return out
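Each projection unit consumes the concatenation of all earlier outputs of the opposite type, so the input widths grow by `nr` per step, and the two loops in `__init__` must agree with what `forward` actually concatenates. The pure-Python sketch below replays both while counting channels only (no tensors). It assumes, as the code above shows, that `self.initial` ends in `nr` channels and every `DenseProjection` emits `nr` channels; `init_channels` and `forward_channels` are hypothetical names.

```python
def init_channels(depth, nr):
    # Replays the two loops in DDBPN.__init__: (in_channels, bottleneck) pairs
    up, channels = [], nr
    for i in range(depth):
        up.append((channels, i > 1))
        if i != 0:
            channels += nr
    down, channels = [], nr
    for i in range(depth - 1):
        down.append((channels, i != 0))
        channels += nr
    return up, down


def forward_channels(depth, nr):
    # Replays DDBPN.forward, tracking only len(h_list) and len(l_list)
    up_in, down_in = [], []
    n_h = n_l = 0
    for i in range(depth - 1):
        up_in.append(nr if i == 0 else n_l * nr)  # x or cat(l_list)
        n_h += 1
        down_in.append(n_h * nr)                  # cat(h_list)
        n_l += 1
    up_in.append(n_l * nr)                        # final up-projection
    n_h += 1
    return up_in, down_in, n_h * nr               # reconstruction input


depth, nr = 6, 32
up, down = init_channels(depth, nr)
up_in, down_in, recon_in = forward_channels(depth, nr)
assert [c for c, _ in up] == up_in and [c for c, _ in down] == down_in
# A 1x1 bottleneck is used exactly when the input is wider than nr
assert all(b == (c > nr) for c, b in up + down)
assert recon_in == depth * nr
print(up_in)    # [32, 32, 64, 96, 128, 160]
print(down_in)  # [32, 64, 96, 128, 160]
```

The last assertion corresponds to the `self.depth * nr` input width of the reconstruction convolution.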



================================================
FILE: src/model/edsr.py
================================================
from model import common

import torch.nn as nn

url = {
    'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',
    'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',
    'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',
    'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',
    'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',
    'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt'
}

def make_model(args, parent=False):
    return EDSR(args)

class EDSR(nn.Module):
    def __init__(self, args, conv=common.default_conv):
        super(EDSR, self).__init__()

        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3 
        scale = args.scale[0]
        act = nn.ReLU(True)
        url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale)
        if url_name in url:
            self.url = url[url_name]
        else:
            self.url = None
        self.sub_mean = common.MeanShift(args.rgb_range)
        self.add_mean = common.MeanShift(args.rgb_range, sign=1)

        # define head module
        m_head = [conv(args.n_colors, n_feats, kernel_size)]

        # define body module
        m_body = [
            common.ResBlock(
                conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
            ) for _ in range(n_resblocks)
        ]
        m_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        m_tail = [
            common.Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, args.n_colors, kernel_size)
        ]

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*m_body)
        self.tail = nn.Sequential(*m_tail)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)

        res = self.body(x)
        res += x

        x = self.tail(res)
        x = self.add_mean(x)

        return x 

    def load_state_dict(self, state_dict, strict=True):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') == -1:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))
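`EDSR.load_state_dict` copies every checkpoint tensor whose name exists in the model. It tolerates failures only for parameters whose name contains `tail` (the upsampler), which is what lets an x2 checkpoint initialize an x3 or x4 model. The sketch below models that control flow with plain dictionaries of shape tuples. The parameter names and shapes are made up for illustration, and the helper `load_shapes` is hypothetical. It treats any shape difference as a failure; the real `Tensor.copy_` would also silently accept a source that broadcasts to the destination shape.

```python
def load_shapes(own, ckpt, strict=True):
    """Pure-Python model of EDSR.load_state_dict using shape tuples.

    own / ckpt map parameter names to shapes. Returns the names that
    would be copied. Shape mismatches or unexpected keys are ignored
    only for 'tail' parameters; anything else raises.
    """
    copied = []
    for name, shape in ckpt.items():
        if name in own:
            if own[name] == shape:
                copied.append(name)
            elif 'tail' not in name:
                raise RuntimeError(
                    'shape mismatch for {}: model {} vs checkpoint {}'
                    .format(name, own[name], shape))
            # mismatched tail parameter: keep the freshly initialized one
        elif strict and 'tail' not in name:
            raise KeyError('unexpected key "{}" in state_dict'.format(name))
    return copied


# Hypothetical x2 model loading a checkpoint whose upsampler differs
own = {'head.0.weight': (64, 3, 3, 3), 'tail.0.0.weight': (256, 64, 3, 3)}
ckpt = {'head.0.weight': (64, 3, 3, 3), 'tail.0.0.weight': (576, 64, 3, 3)}
print(load_shapes(own, ckpt))  # only the head weight is copied
```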

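EDSR picks its pretrained checkpoint purely from the `(n_resblocks, n_feats, scale)` triple, formatted as `r{blocks}f{feats}x{scale}`; any other configuration gets `self.url = None`. The stand-alone sketch below mirrors that lookup using the URL table from `edsr.py`. The function name `pretrained_url` is a hypothetical helper, not part of the repository.

```python
# URL table copied from src/model/edsr.py
EDSR_URLS = {
    'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',
    'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',
    'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',
    'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',
    'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',
    'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt',
}


def pretrained_url(n_resblocks, n_feats, scale):
    """Build the key the way EDSR.__init__ does and return the URL or None."""
    key = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale)
    return EDSR_URLS.get(key)


print(pretrained_url(16, 64, 2))   # baseline x2 checkpoint
print(pretrained_url(16, 64, 8))   # None: no pretrained x8 model
```

Only the baseline (16 blocks, 64 features) and full (32 blocks, 256 features) configurations at scales 2, 3 and 4 have a download URL.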


================================================
FILE: src/model/han.py
================================================
from model import common
import torch
import torch.nn as nn
import pdb

def make_model(args, parent=False):
    return HAN(args)

## Channel Attention (CA) Layer
class CALayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
                nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
                nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y

class LAM_Module(nn.Module):
    """ Layer attention module"""
    def __init__(self, in_dim):
        super(LAM_Module, self).__init__()
        self.chanel_in = in_dim


        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax  = nn.Softmax(dim=-1)
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X N X C X H X W)
            returns :
                out : attention value + input feature
                attention: B X N X N
        """
        m_batchsize, N, C, height, width = x.size()
        proj_query = x.view(m_batchsize, N, -1)
        proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1)
        energy = torch.bmm(proj_query, proj_key)
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, N, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, N, C, height, width)

        out = self.gamma*out + x
        out = out.view(m_batchsize, -1, height, width)
        return out

class CSAM_Module(nn.Module):
    """ Channel-Spatial attention module"""
    def __init__(self, in_dim):
        super(CSAM_Module, self).__init__()
        self.chanel_in = in_dim


        self.conv = nn.Conv3d(1, 1, 3, 1, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        #self.softmax  = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
    def forward(self,x):
        """
            inputs :
                x : input feature maps( B X N X C X H X W)
            returns :
                out : attention value + input feature
                attention: B X N X N
        """
        m_batchsize, C, height, width = x.size()
        out = x.unsqueeze(1)
        out = self.sigmoid(self.conv(out))
        
        # proj_query = x.view(m_batchsize, N, -1)
        # proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1)
        # energy = torch.bmm(proj_query, proj_key)
        # energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
        # attention = self.softmax(energy_new)
        # proj_value = x.view(m_batchsize, N, -1)

        # out = torch.bmm(attention, proj_value)
        # out = out.view(m_batchsize, N, C, height, width)

        out = self.gamma*out
        out = out.view(m_batchsize, -1, height, width)
        x = x * out + x
        return x

## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    def __init__(
        self, conv, n_feat, kernel_size, reduction,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(RCAB, self).__init__()
        modules_body = []
        for i in range(2):
            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: modules_body.append(nn.BatchNorm2d(n_feat))
            if i == 0: modules_body.append(act)
        modules_body.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*modules_body)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        #res = self.body(x).mul(self.res_scale)
        res += x
        return res

## Residual Group (RG)
class ResidualGroup(nn.Module):
    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
        super(ResidualGroup, self).__init__()
        modules_body = [
            RCAB(
                conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=act, res_scale=res_scale) \
            for _ in range(n_resblocks)]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res

## Holistic Attention Network (HAN)
class HAN(nn.Module):
    def __init__(self, args, conv=common.default_conv):
        super(HAN, self).__init__()
        
        n_resgroups = args.n_resgroups
        n_resblocks = args.n_resblocks
        n_feats = args.n_feats
        kernel_size = 3
        reduction = args.reduction 
        scale = args.scale[0]
        act = nn.ReLU(True)
        
        # RGB mean for DIV2K
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)
        
        # define head module
        modules_head = [conv(args.n_colors, n_feats, kernel_size)]

        # define body module
        modules_body = [
            ResidualGroup(
                conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \
            for _ in range(n_resgroups)]

        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        modules_tail = [
            common.Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, args.n_colors, kernel_size)]

        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)

        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.csa = CSAM_Module(n_feats)
        self.la = LAM_Module(n_feats)
        # LAM stacks the outputs of all n_resgroups groups plus the final body conv
        self.last_conv = nn.Conv2d(n_feats*(n_resgroups+1), n_feats, 3, 1, 1)
        self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)
        res = x
        #pdb.set_trace()
        for name, midlayer in self.body._modules.items():
            res = midlayer(res)
            #print(name)
            if name=='0':
                res1 = res.unsqueeze(1)
            else:
                res1 = torch.cat([res.unsqueeze(1),res1],1)
        #res = self.body(x)
        out1 = res
        #res3 = res.unsqueeze(1)
        #res = torch.cat([res1,res3],1)
        res = self.la(res1)
        out2 = self.last_conv(res)

        out1 = self.csa(out1)
        out = torch.cat([out1, out2], 1)
        res = self.last(out)
        
        res += x
        #res = self.csa(res)

        x = self.tail(res)
        x = self.add_mean(x)

        return x 

    def load_state_dict(self, state_dict, strict=False):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') >= 0:
                        print('Replace pre-trained upsampler to new one...')
                    else:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))
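`LAM_Module` treats each of the N stacked layer outputs as one token: it forms the N x N Gram matrix of flattened features, subtracts it from its row maximum, applies a softmax, and mixes the layers with the resulting weights before the learnable residual `gamma * out + x`. Because softmax is shift-invariant, `softmax(rowmax - E)` equals `softmax(-E)`, so layer pairs with larger inner products receive smaller weights. The sketch below is a pure-Python version for a single batch element, written for illustration; `layer_attention` is a hypothetical name and lists stand in for tensors.

```python
import math


def layer_attention(feats, gamma):
    """Pure-Python sketch of LAM_Module.forward for one batch element.

    feats: list of N flattened layer features (each a list of C*H*W floats).
    Returns (out, attention) with
    out[i] = gamma * sum_j attention[i][j] * feats[j] + feats[i].
    """
    n = len(feats)
    # energy[i][j] = <F_i, F_j>, i.e. proj_query @ proj_key
    energy = [[sum(a * b for a, b in zip(feats[i], feats[j])) for j in range(n)]
              for i in range(n)]
    attention = []
    for row in energy:
        row_max = max(row)
        shifted = [row_max - e for e in row]       # energy_new
        peak = max(shifted)                        # stabilize exp()
        exps = [math.exp(s - peak) for s in shifted]
        total = sum(exps)
        attention.append([v / total for v in exps])  # softmax over last dim
    out = []
    for i in range(n):
        # attention-weighted mix of all layers (torch.bmm(attention, proj_value))
        mixed = [sum(attention[i][j] * feats[j][k] for j in range(n))
                 for k in range(len(feats[i]))]
        out.append([gamma * m + x for m, x in zip(mixed, feats[i])])
    return out, attention


feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, attn = layer_attention(feats, gamma=0.0)
print(out == feats)  # gamma starts at zero, so the module begins as identity
```

Since `gamma` is initialized to zero in `LAM_Module.__init__`, the branch starts as an identity mapping and the attention contribution is learned gradually.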

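The in-channel counts of HAN's fusion convolutions follow from `forward`: `self.body` holds `n_resgroups` residual groups plus one trailing convolution, and the loop stacks the output of every one of them, so `LAM_Module` receives N = n_resgroups + 1 layers and returns N * n_feats channels. `self.last` then fuses the CSAM branch with the `last_conv` output, giving 2 * n_feats channels. A quick arithmetic check with a hypothetical helper:

```python
def han_fusion_channels(n_resgroups, n_feats):
    """Channel bookkeeping around HAN's layer-attention branch."""
    n_layers = n_resgroups + 1          # residual groups + trailing body conv
    return {
        "n_layers": n_layers,
        "last_conv_in": n_layers * n_feats,  # LAM output, input to last_conv
        "last_in": 2 * n_feats,              # cat([CSAM branch, last_conv output])
    }


print(han_fusion_channels(10, 64))  # e.g. 10 groups of 64 features
```

For example, 10 residual groups with 64 features give 11 x 64 = 704 input channels for `last_conv`.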
================================================
FILE: src/model/matrixmodel.py
================================================
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from model import ops
import pdb


try:
    from model.dcn.deform_conv import ModulatedDeformConvPack as DCN
except ImportError:
    raise ImportError('Failed to import DCNv2 module.')

BN_MOMENTUM = 0.1
logger = logging.getLogger(__name__)

def initialize_weights(net_l, scale=1):
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)

class ResBlock(nn.Module):
    def __init__(
        self, num_channels, kernel_size=3,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1,**kwargs):

        super(ResBlock, self).__init__()
        m = []
        for i in range(2):
            m.append(nn.Conv2d(num_channels, num_channels, kernel_size, stride=1, padding=kernel_size//2, bias=bias))
            if bn: m.append(nn.BatchNorm2d(num_channels))
            if i == 0: m.append(act)

        self.body = nn.Sequential(*m)
        self.res_scale = res_scale
        initialize_weights([self.body], 0.1)

    def forward(self, x):
        res = self.body(x).mul(self.res_scale)
        res += x

        return res

class BFN(nn.Module):
    def __init__(self, num_channels, kernel_size, reduction, n_blocks, block):
        super(BFN, self).__init__()

        branch1=[]
        branch1.append(self._make_blocks(num_channels[0], num_channels[0], kernel_size, reduction, n_blocks, block))
        branch1.append(nn.Conv2d(num_channels[0], num_channels[0], kernel_size, stride=1, padding=1, bias=True))
        branch2=[]
        branch2.append(self._make_blocks(num_channels[1], num_channels[1], kernel_size, reduction, n_blocks, block))
        branch2.append(nn.Conv2d(num_channels[1], num_channels[1], kernel_size, stride=1, padding=1, bias=True))
        branch3=[]
        branch3.append(self._make_blocks(num_channels[2], num_channels[2], kernel_size, reduction, n_blocks, block))
        branch3.append(nn.Conv2d(num_channels[2], num_channels[2], kernel_size, stride=1, padding=1, bias=True))
        self.branch1 = nn.Sequential(*branch1)
        self.branch2 = nn.Sequential(*branch2)
        self.branch3 = nn.Sequential(*branch3)
        #self.act=nn.ReLU(True)


    def _make_blocks(self, in_channels, num_channels, kernel_size, reduction, n_blocks, block):
        blocks = []
        blocks = [block(in_channels=in_channels, num_channels=num_channels, reduction=reduction) \
            for _ in range(n_blocks)]
        blocks.append(nn.Conv2d(num_channels, num_channels, kernel_size, stride=1, padding=1, bias=True))
        
        return nn.Sequential(*blocks)

    def forward(self, x):
        assert type(x) is tuple and len(x)==3
        #branch1
        res1 = x[0]
        out1 = self.branch1(x[0])
        out1 += res1

        #branch2
        res2 = x[1]
        out2 = self.branch2(x[1])
        out2 += res2

        #branch3
        res3 = x[2]
        out3 = self.branch3(x[2])
        out3 += res3

        return (out1,out2,out3)

class BFN1(nn.Module):
    def __init__(self, num_channels, kernel_size, reduction, n_blocks, block):
        super(BFN1, self).__init__()

        branch1=[]
        branch1.append(self._make_blocks(num_channels, num_channels, kernel_size, reduction, n_blocks, block))
        branch1.append(nn.Conv2d(num_channels, num_channels, kernel_size, stride=1, padding=1, bias=True))
        self.branch1 = nn.Sequential(*branch1)
        #self.act=nn.ReLU(True)


    def _make_blocks(self, in_channels, num_channels, kernel_size, reduction, n_blocks, block):
        blocks = []
        blocks = [block(in_channels=in_channels, num_channels=num_channels, reduction=reduction) \
            for _ in range(n_blocks)]
        blocks.append(nn.Conv2d(num_channels, num_channels, kernel_size, stride=1, padding=1, bias=True))
        
        return nn.Sequential(*blocks)

    def forward(self, x):
        #branch1
        res1 = x
        out1 = self.branch1(x)
        out1 += res1

        return out1

class BFN2(nn.Module):
    def __init__(self, num_channels, kernel_size, reduction, n_blocks, block):
        super(BFN2, self).__init__()

        branch1=[]
        branch1.append(self._make_blocks(num_channels[0], num_channels[0], kernel_size, reduction, n_blocks, block))
        branch1.append(nn.Conv2d(num_channels[0], num_channels[0], kernel_size, stride=1, padding=1, bias=True))
        branch2=[]
        branch2.append(self._make_blocks(num_channels[1], num_channels[1], kernel_size, reduction, n_blocks, block))
        branch2.append(nn.Conv2d(num_channels[1], num_channels[1], kernel_size, stride=1, padding=1, bias=True))
        self.branch1 = nn.Sequential(*branch1)
        self.branch2 = nn.Sequential(*branch2)
        #self.act=nn.ReLU(True)


    def _make_blocks(self, in_channels, num_channels, kernel_size, reduction, n_blocks, block):
        blocks = []
        blocks = [block(in_channels=in_channels, num_channels=num_channels, reduction=reduction) \
            for _ in range(n_blocks)]
        blocks.append(nn.Conv2d(num_channels, num_channels, kernel_size, stride=1, padding=1, bias=True))
        
        return nn.Sequential(*blocks)

    def forward(self, x):
        assert type(x) is tuple and len(x)==2
        #branch1
        res1 = x[0]
        out1 = self.branch1(x[0])
        out1 += res1

        #branch2
        res2 = x[1]
        out2 = self.branch2(x[1])
        out2 += res2

        return (out1,out2)

class EoctResBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, num_channels, stride=1, downsample=None, res_scale=1, **kwargs):
        super(EoctResBlock, self).__init__()
        self.num_channels = num_channels # (64,64,64)
        self.stride = stride
        self.downsample = downsample
        self.res_scale = res_scale
        self.conv1 = ops.EoctConv(in_channels, num_channels, stride=stride)
        self.conv2 = ops.EoctConv(num_channels, num_channels)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        #out = ops.bn(out, self.num_channels)
        out = ops.relu(out)

        out = self.conv2(out)
        #out = ops.bn(out, self.num_channels)
        
        if self.downsample is not None:
            residual = self.downsample(x)

        #out = out * self.res_scale + residual
        out = ops.tupleSum(out,residual)
        #pdb.set_trace()
        out = ops.relu(out)

        return out

class EoctBottleneck(nn.Module):
    def __init__(self, in_channels, num_channels, stride=1, downsample=None, res_scale=1, **kwargs):
        super(EoctBottleneck, self).__init__()
        self.num_channels = num_channels
        self.stride = stride
        self.downsample = downsample
        self.res_scale = res_scale
        expand = 6
        linear = 0.8
        self.conv1 = ops.EoctConv(in_channels, ops.tupleMultiply(num_channels,expand), kernel_size=1, padding=1//2)
        #self.bn1 = nn.BatchNorm2d(num_channels*expand, momentum=BN_MOMENTUM)
        self.conv2 = ops.EoctConv(ops.tupleMultiply(num_channels,expand), int(ops.tupleMultiply(num_channels,linear)), kernel_size=1, padding=1//2)
        self.conv3 = ops.EoctConv(int(ops.tupleMultiply(num_channels,linear)), num_channels, kernel_size=3, padding=3//2)
    
    def forward(self, x):
        residual = x

        out = self.conv1(x)
        #out = ops.bn(out, self.num_channels)
        out = ops.relu(out)

        out = self.conv2(out)
        #out = ops.bn(out, self.num_channels)
        
        out = self.conv3(out)
        #out = ops.bn(out, self.num_channels)
        
        if self.downsample is not None:
            residual = self.downsample(x)

        #out = out * self.res_scale + residual
        out = ops.tupleSum(out,residual)
        out = ops.relu(out)

        return out
        

class CALayer(nn.Module):
    def __init__(self, in_channels, num_channels, reduction=16):
        super(CALayer, self).__init__()
        
        # feature channel downscale and upscale --> channel weight
        # no trailing commas: a trailing comma would turn each attribute into a tuple
        self.conv1 = ops.EoctConv(in_channels, num_channels // reduction, 1, padding=0, bias=True)
        self.conv2 = ops.EoctConv(num_channels // reduction, num_channels, 1, padding=0, bias=True)


    def forward(self, x):
    
        out = ops.avg_pool2d(x)
        
        out = self.conv1(out)
        out = ops.relu(out)
        out = self.conv2(out)
        out = ops.sigmoid(out)
        
        return x * out

class CAEoctResBlock(nn.Module):
    def __init__(self, in_channels, num_channels, reduction, bias=True, res_scale=1, **kwargs):
        super(CAEoctResBlock, self).__init__()
        self.num_channels = num_channels # [64,64,64,64]
        self.res_scale = res_scale
        self.conv1 = ops.EoctConv(in_channels, num_channels, stride=1)
        self.conv2 = ops.EoctConv(num_channels, num_channels)
        self.caLayer = CALayer(num_channels, num_channels, reduction)
        
    def forward(self, x):
        res = x
        
        out = self.conv1(x)
        out = ops.relu(out)
        out = self.conv2(out)
        
        out = self.caLayer(out)
        
        out = ops.tupleSum(out,res)
        #out = out * self.res_scale + res
        out = ops.relu(out)
        
        
        return out

blocks_dict = {
    'BASIC':ResBlock,
    'EctBASIC': EoctResBlock,
    'EctBOTTLENECK': EoctBottleneck,
    'CAEctBASIC':CAEoctResBlock
}

def make_model(args, parent=False):
    return MatrixModelG2(args)

class MatrixModel(nn.Module):
    def __init__(self, args):
        super(MatrixModel, self).__init__()
        
        n_groups = args.n_resgroups
        n_blocks = args.n_resblocks
        num_channels = (64, 64, 64)
        kernel_size = 3
        reduction = args.reduction 
        scale = args.scale
        block = EoctResBlock
        
        # RGB mean for DIV2K
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std)
        self.add_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
        
        self.first_conv = ops.EoctConv(3, 64)
        
        modules_body1 = []
        modules_body1.append(self._make_blocks(64, 64, kernel_size, reduction, n_blocks, block))
        modules_body1.append(ops.EoctConv(64, (64,64), kernel_size))
        
        modules_body2 = []
        modules_body2.append(self._make_blocks((64,64), (64,64), kernel_size, reduction, n_blocks, block))
        modules_body2.append(ops.EoctConv((64,64), num_channels, kernel_size))
        
        modules_body3 = []
        modules_body3.append(self._make_blocks(num_channels, num_channels, kernel_size, reduction, n_blocks, block))
        modules_body3.append(ops.EoctConv(num_channels, 64, kernel_size))
        
        modules_tail = [
            ops._UpsampleBlock(num_channels[0], scale=scale),
            nn.Conv2d(num_channels[0], 3, kernel_size, 1, 1)]
        
        self.body1 = nn.Sequential(*modules_body1)
        self.body2 = nn.Sequential(*modules_body2)
        self.body3 = nn.Sequential(*modules_body3)
        self.tail = nn.Sequential(*modules_tail)
        
    def _make_blocks(self, in_channels, num_channels, kernel_size, reduction, n_blocks, block):
        blocks = []
        blocks = [block(in_channels=in_channels, num_channels=num_channels, reduction=reduction) \
            for _ in range(n_blocks)]
        blocks.append(ops.EoctConv(num_channels, num_channels, kernel_size))
        
        return nn.Sequential(*blocks)
        
    def forward(self, x):
        
        x = self.sub_mean(x)
        x = self.first_conv(x)

        res = x
        x = self.body1(x)
        x = self.body2(x)
        x = self.body3(x)
        x += res
        #pdb.set_trace()

        out = self.tail(x)
        out = self.add_mean(out)
        
        return out

    def load_state_dict(self, state_dict, strict=False):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') >= 0:
                        print('Replace pre-trained upsampler to new one...')
                    else:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))

class RERB(nn.Module):
    def __init__(self, in_channels, num_channels, kernel_size, reduction, n_blocks, block):
        super(RERB, self).__init__()

        blocks = []
        blocks.append(self._make_blocks(in_channels, num_channels, kernel_size, reduction, n_blocks, block))
        blocks.append(ops.EoctConv(num_channels, num_channels, kernel_size))
        self.body = nn.Sequential(*blocks)

    def _make_blocks(self, in_channels, num_channels, kernel_size, reduction, n_blocks, block):
        blocks = []
        blocks = [block(in_channels=in_channels, num_channels=num_channels, reduction=reduction) \
            for _ in range(n_blocks)]
        blocks.append(ops.EoctConv(num_channels, num_channels, kernel_size))
        
        return nn.Sequential(*blocks)

    def forward(self, x):
        res = x
        x = self.body(x)
        x = ops.tupleSum(x,res)
        x = ops.relu(x)

        return x


class MatrixModelB(nn.Module):
    def __init__(self, args):
        super(MatrixModelB, self).__init__()
        
        num_channels = (64, 64, 64)
        kernel_size = 3
        reduction = args.reduction 
        scale = args.scale
        block = blocks_dict[args.block]
        
        # RGB mean for DIV2K
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std)
        self.add_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
        
        self.first_conv = nn.Conv2d(3, 64, kernel_size, stride=1, padding=1, bias=True)
        
        modules_stage1 = []
        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))
        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))
        self.stage1 = nn.Sequential(*modules_stage1)
        self.stage1_conv = ops.EoctConv(64, (64,64), kernel_size)

        modules_stage2 = []
        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))
        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))
        self.stage2 = nn.Sequential(*modules_stage2)
        self.stage2_conv = ops.EoctConv((64,64), num_channels, kernel_size)

        modules_stage3 = []
        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))
        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))
        self.stage3 = nn.Sequential(*modules_stage3)
        self.stage3_conv = ops.EoctConv(num_channels, num_channels, kernel_size)
        
        '''
        modules_stage4 = []
        modules_stage4.append(BFN(num_channels, kernel_size, reduction, 5, block))
        self.stage4 = nn.Sequential(*modules_stage4)
        self.stage4_conv = ops.EoctConv(num_channels, num_channels, kernel_size)

        
        modules_body = []
        for i in range(n_groups):
            modules_body.append(RERB(num_channels, num_channels, kernel_size, reduction, n_blocks, block))
        modules_body.append(ops.EoctConv(num_channels, num_channels, kernel_size))
        '''
        self.fusion_conv1 = ops.EoctConv(num_channels, num_channels, kernel_size)
        self.fusion_conv2 = ops.EoctConv(num_channels, num_channels, kernel_size)
        self.fusion_conv3 = ops.EoctConv(num_channels, num_channels, kernel_size)
        self.conv_last = ops.EoctConv(num_channels, 64, kernel_size)
        
        modules_tail1 = [
            ops._UpsampleBlock(64, scale=scale),
            nn.Conv2d(64, 3, kernel_size, 1, 1)]
        
        #self.body = nn.Sequential(*modules_body)
        self.tail1 = nn.Sequential(*modules_tail1)
        '''
        modules_tail2 = [
            ops._UpsampleBlock(64, scale=scale),
            nn.Conv2d(64, 3, kernel_size, 1, 1)]
        
        #self.body = nn.Sequential(*modules_body)
        self.tail2 = nn.Sequential(*modules_tail2)
        
        modules_tail3 = [
            ops._UpsampleBlock(64, scale=scale),
            nn.Conv2d(64, 3, kernel_size, 1, 1)]
        
        #self.body = nn.Sequential(*modules_body)
        self.tail3 = nn.Sequential(*modules_tail3)
        '''
              
    def forward(self, x):
        
        x = self.sub_mean(x)
        x = self.first_conv(x)
        residual = x
        #pdb.set_trace()

        #stage1
        x = self.stage1(x)
        x = self.stage1_conv(x)
        #pdb.set_trace()
        L1_fea = x[0]

        #stage2
        x = self.stage2(x)
        x = self.stage2_conv(x)
        L2_fea = x[1]

        #stage3
        x = self.stage3(x)
        x = self.stage3_conv(x)
        L3_fea = x[2]
        
        #stage4
        #x = self.stage4(x)
        #x = self.stage4_conv(x)

        x = (L1_fea, L2_fea, L3_fea)
        res1 = x
        x = self.fusion_conv1(x)
        x = ops.tupleSum(x,res1)
        res2 = x
        x = self.fusion_conv2(x)
        x = ops.tupleSum(x,res2)
        res3 = x
        x = self.fusion_conv3(x)
        x = ops.tupleSum(x,res3)
        out = self.conv_last(x)
        out += residual

        out = self.tail1(out)
        out = self.add_mean(out)

        #out2 = self.tail1(x[1])
        #out2 = self.add_mean(out2)

        #out3 = self.tail2(x[2])
        #out3 = self.add_mean(out3)
        #pdb.set_trace()
        
        return out

    def load_state_dict(self, state_dict, strict=False):
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name in own_state:
                if isinstance(param, nn.Parameter):
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except Exception:
                    if name.find('tail') >= 0:
                        print('Replace pre-trained upsampler to new one...')
                    else:
                        raise RuntimeError('While copying the parameter named {}, '
                                           'whose dimensions in the model are {} and '
                                           'whose dimensions in the checkpoint are {}.'
                                           .format(name, own_state[name].size(), param.size()))
            elif strict:
                if name.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if len(missing) > 0:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))

class PDF(nn.Module):
    ''' Alignment module using Pyramid, Deformable convolution and Fusion,
    with 3 pyramid levels processed bottom-up (from the full-resolution
    level L1 to the 1/4-resolution level L3).
    '''

    def __init__(self, nf=64, groups=8):
        super(PDF, self).__init__()
        # L1: level 1, original spatial size
        #self.L1_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L1_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        # L2: level 2, 1/2 spatial size
        self.L2_offset_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)  # stride 2: downsample the L1 offset to 1/2 size
        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # L3: level 3, 1/4 spatial size
        self.L3_offset_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)  # stride 2: downsample the L2 offset to 1/4 size
        self.L3_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L3_offset_conv3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        self.L3_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=False)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        self.conv_last = nn.Conv2d(nf * 3, nf, 3, 1, 1, bias=True)

    def forward(self, nbr_fea_l):
        '''Refine multi-scale features with deformable convolution, bottom-up.
        nbr_fea_l: [L1, L2, L3] feature maps of shape [B, nf, H, W],
        [B, nf, H/2, W/2] and [B, nf, H/4, W/4].
        Returns the fused features as a single [B, nf, H, W] tensor.
        '''
        # L1
        L1_offset = nbr_fea_l[0]
        #L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
        L1_offset = self.lrelu(self.L1_offset_conv2(L1_offset))
        L1_fea = self.lrelu(self.L1_dcnpack([nbr_fea_l[0], L1_offset]))
        L1_f = L1_fea
        # L2
        L2_offset = nbr_fea_l[1]
        L1_offset = self.lrelu(self.L2_offset_conv1(L1_offset))
        #L1_offset = F.interpolate(L1_offset, scale_factor=1/2, mode='bilinear', align_corners=False)
        L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L1_offset * 2], dim=1)))
        #L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
        L2_fea = self.L2_dcnpack([nbr_fea_l[1], L2_offset])
        L1_fea = self.lrelu(self.L2_offset_conv3(L1_fea))
        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L1_fea], dim=1)))
        L2_f = L2_fea
        # L3
        L3_offset = nbr_fea_l[2]
        #L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
        L2_offset = self.L3_offset_conv1(L2_offset)
        L3_offset = self.lrelu(self.L3_offset_conv2(torch.cat([L3_offset, L2_offset * 2], dim=1)))
        #L3_offset = self.lrelu(self.L3_offset_conv3(L3_offset))
        L3_fea = self.L3_dcnpack([nbr_fea_l[2], L3_offset])
        L2_fea = self.lrelu(self.L3_offset_conv3(L2_fea))
        L3_fea = self.L3_fea_conv(torch.cat([L3_fea, L2_fea], dim=1))
        # Fusion
        L3_fea = self.upsample2(L3_fea)
        L2_f = self.upsample(L2_f)
        L_fea = torch.cat([torch.cat([L1_f, L2_f], dim=1),L3_fea],dim=1)
        L_fea = self.lrelu(self.conv_last(L_fea))
        return L_fea

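# Illustrative sketch (not part of the original model code): spatial-size
# bookkeeping for PDF. A 3x3 nn.Conv2d with stride 2 and padding 1 maps a side
# of length n to (n + 2*1 - 3) // 2 + 1, so the stride-2 convs halve L1 to L2
# and L2 to L3. The fusion step upsamples L2 by 2 and L3 by 4 and concatenates
# them with L1 along the channel axis, which only works when both round trips
# return exactly to the L1 size (e.g. H and W divisible by 4). It assumes the
# caller passes L2 and L3 inputs at these sizes; the helper names are
# hypothetical.
def conv_out_size(n, kernel=3, stride=2, padding=1):
    """Output length of nn.Conv2d along one spatial axis (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def pdf_level_sizes(h, w):
    """(H, W) of L1, L2 and L3 as produced by PDF's stride-2 convs."""
    l2 = (conv_out_size(h), conv_out_size(w))
    l3 = (conv_out_size(l2[0]), conv_out_size(l2[1]))
    return [(h, w), l2, l3]


def pdf_fusion_ok(h, w):
    """True if upsampling L2 by 2 and L3 by 4 recovers the L1 size for torch.cat."""
    (h1, w1), (h2, w2), (h3, w3) = pdf_level_sizes(h, w)
    return (2 * h2, 2 * w2) == (h1, w1) and (4 * h3, 4 * w3) == (h1, w1)

# Usage: pdf_level_sizes(48, 64) gives [(48, 64), (24, 32), (12, 16)], and
# pdf_fusion_ok(50, 48) is False because L3 would upsample to 52 rows.
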
class PD(nn.Module):
    ''' Module using Pyramid and Deformable convolution (PD) with 3 pyramid
    levels, processed top-down: L3 (1/4 size) -> L2 (1/2) -> L1 (full size).
    '''

    def __init__(self, nf=64, groups=8):
        super(PD, self).__init__()
        # L3: level 3, 1/4 spatial size
        #self.L3_offset_conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        # L2: level 2, 1/2 spatial size
        #self.L2_offset_conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # L1: level 1, original spatial size
        #self.L1_offset_conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
                              extra_offset_mask=True)
        self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # Cascading DCN
        #self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        #self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        #self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,extra_offset_mask=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=False)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        #self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        #self.conv_last = nn.Conv2d(nf * 3, nf, 3, 1, 1, bias=True)

    def forward(self, nbr_fea_l):
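        '''Refine multi-scale features with deformable convolution, top-down.
        nbr_fea_l: [L1, L2, L3] feature maps of shape [B, nf, H, W],
        [B, nf, H/2, W/2] and [B, nf, H/4, W/4].
        Returns the tuple (L1, L2, L3) of refined features, each at its input size.
        '''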
        # L3
        L3_offset = nbr_fea_l[2]
        #L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
        L3_fea = self.lrelu(self.L3_dcnpack([nbr_fea_l[2], L3_offset]))
        L3_f = L3_fea
        # L2
        L2_offset = nbr_fea_l[1]
        #L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))
        L3_offset = self.upsample(L3_offset)
        L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset * 2], dim=1)))
        L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
        L2_fea = self.L2_dcnpack([nbr_fea_l[1], L2_offset])
        L3_fea = self.upsample(L3_fea)
        #pdb.set_trace()
        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))
        L2_f = L2_fea
        # L1
        L1_offset = nbr_fea_l[0]
        #L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
        L2_offset = self.upsample(L2_offset)
        L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))
        L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))
        L1_fea = self.L1_dcnpack([nbr_fea_l[0], L1_offset])
        L2_fea = self.upsample(L2_fea)
        L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))
        # Cascading
        #offset = L1_fea
        #offset = self.lrelu(self.cas_offset_conv1(offset))
        #offset = self.lrelu(self.cas_offset_conv2(offset))
        #L1_fea = self.lrelu(self.cas_dcnpack([L1_fea, offset]))
        
        #L3_f = self.upsample2(L3_f)
        #L2_f = self.upsample(L2_f)
        #L_fea = torch.cat([torch.cat([L1_fea, L2_f], dim=1),L3_f],dim=1)
        #L_fea = self.lrelu(self.conv_last(L_fea))

        return (L1_fea, L2_f, L3_f)
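
# Illustrative sketch (not part of the original model code): PD passes each
# coarser level's offset features up one level with
# nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) and then
# multiplies them by 2 before concatenating them with the finer level. The
# factor 2 appears to mirror the PCD alignment module of EDVR, where doubling
# the resolution doubles a pixel displacement; here the tensors are learned
# offset features, so the factor acts as a fixed rescaling rather than an exact
# geometric one. Bilinear interpolation is separable, so this 1-D, pure-Python
# version shows what happens along one axis. The helper names are hypothetical.
def upsample_linear_x2(values):
    """1-D linear x2 upsampling with align_corners=False, as in PyTorch."""
    n = len(values)
    out = []
    for i in range(2 * n):
        # Map the output pixel centre back to input coordinates, clamped at 0.
        x = max((i + 0.5) / 2 - 0.5, 0.0)
        x0 = int(x)  # floor, since x >= 0
        x1 = min(x0 + 1, n - 1)  # replicate the last sample at the border
        lam = x - x0
        out.append((1 - lam) * values[x0] + lam * values[x1])
    return out


def propagate_offsets(coarse):
    """Upsample coarse offset features by 2, then rescale them by 2, as PD does."""
    return [2 * v for v in upsample_linear_x2(coarse)]

# Usage: upsample_linear_x2([0.0, 4.0]) gives [0.0, 1.0, 3.0, 4.0], and
# propagate_offsets([0.0, 4.0]) gives [0.0, 2.0, 6.0, 8.0].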

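# Illustrative sketch (not part of the original model code): what the
# sub_mean / add_mean layers of MatrixModelC compute for one pixel, assuming
# ops.MeanShift follows the EDSR convention: a 1x1 conv with weight
# diag(1 / std) and bias sign * rgb_range * mean / std. With the DIV2K mean and
# rgb_std = (1.0, 1.0, 1.0) used below, sub_mean (sign=-1) subtracts
# rgb_range * mean and add_mean (sign=1) adds it back. rgb_range=255 is an
# assumed default; the model reads it from args.rgb_range. The function name
# `mean_shift_pixel` is hypothetical.
def mean_shift_pixel(rgb, rgb_range=255,
                     mean=(0.4488, 0.4371, 0.4040),
                     std=(1.0, 1.0, 1.0),
                     sign=-1):
    """Apply MeanShift's per-channel affine map to one (R, G, B) pixel."""
    return tuple(c / s + sign * rgb_range * m / s
                 for c, m, s in zip(rgb, mean, std))

# Usage: mean_shift_pixel(mean_shift_pixel(p), sign=1) returns p (up to float
# rounding) when std is 1.0, which is why add_mean undoes sub_mean.
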
class MatrixModelC(nn.Module):
    def __init__(self, args):
        super(MatrixModelC, self).__init__()
        
        num_channels = (64, 64, 64)
        kernel_size = 3
        reduction = args.reduction 
        scale = args.scale
        block = blocks_dict[args.block]
        
        # RGB mean for DIV2K
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std)
        self.add_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
        
        self.first_conv = nn.Conv2d(3, 64, kernel_size, stride=1, padding=1, bias=True)
        
        modules_stage1 = []
        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))
        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))
        self.stage1 = nn.Sequential(*modules_stage1)
        self.stage1_conv = ops.EoctConv(64, (64,64), kernel_size)

        modules_stage2 = []
        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))
        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))
        self.stage2 = nn.Sequential(*modules_stage2)
        self.stage2_conv = ops.EoctConv((64,64), num_channels, kernel_size)

        modules_stage3 = []
        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))
        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))
        self.stage3 = nn.Sequential(*modules_stage3)
        self.stage3_conv = ops.EoctConv(num_channels, num_channels, kernel_size)
        
        '''
        modules_stage4 = []
        modules_stage4.append(BFN(num_channels, kernel_size, reduction, 5, block))
        self.stage4 = nn.Sequential(*modules_stage4)
        self.stage4_conv = ops.EoctConv(num_channels, num_channels, kernel_size)
        '''
        
        self.pd = PD()
        self.pdf = PDF()
        modules_tail1 = [
            ops._UpsampleBlock(64, scale=scale),
            nn.Conv2d(64, 3, kernel_size, 1, 1)]
        self.tail1 = nn.Sequential(*modules_tail1)
        
    def forward(self, x):
        
        x = self.sub_mean(x)
        x = self.first_conv(x)
        residual = x
        #pdb.set_trace()

        #stage1
        x = self.stage1(x)
        x = self.stage1_conv(x)
        #pdb.set_trace()
        L1_fea = x[0]

        #stage2
        x = self.stage2(x)
        x = self.stage2_conv(x)
        L2_fea = x[1]

        #stage3
        x = self.stage3(x)
        x = self.stage3_conv(x)
        L3_f
SYMBOL INDEX (461 symbols across 38 files)

FILE: src/data/__init__.py
  class MyConcatDataset (line 7) | class MyConcatDataset(ConcatDataset):
    method __init__ (line 8) | def __init__(self, datasets):
    method set_scale (line 12) | def set_scale(self, idx_scale):
  class Data (line 16) | class Data:
    method __init__ (line 17) | def __init__(self, args):

FILE: src/data/benchmark.py
  class Benchmark (line 13) | class Benchmark(srdata.SRData):
    method __init__ (line 14) | def __init__(self, args, name='', train=True, benchmark=True):
    method _scan (line 18) | def _scan(self):
    method _set_filesystem (line 39) | def _set_filesystem(self, dir_data):

FILE: src/data/common.py
  function get_patch (line 8) | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=Fa...
  function set_channel (line 34) | def set_channel(*args, n_channels=3):
  function np2Tensor (line 49) | def np2Tensor(*args, rgb_range=255):
  function augment (line 59) | def augment(*args, hflip=True, rot=True):

FILE: src/data/demo.py
  class Demo (line 11) | class Demo(data.Dataset):
    method __init__ (line 12) | def __init__(self, args, name='Demo', train=False, benchmark=False):
    method __getitem__ (line 26) | def __getitem__(self, idx):
    method __len__ (line 34) | def __len__(self):
    method set_scale (line 37) | def set_scale(self, idx_scale):

FILE: src/data/div2k.py
  class DIV2K (line 4) | class DIV2K(srdata.SRData):
    method __init__ (line 5) | def __init__(self, args, name='DIV2K', train=True, benchmark=False):
    method _scan (line 20) | def _scan(self):
    method _set_filesystem (line 27) | def _set_filesystem(self, dir_data):

FILE: src/data/div2kjpeg.py
  class DIV2KJPEG (line 5) | class DIV2KJPEG(div2k.DIV2K):
    method __init__ (line 6) | def __init__(self, args, name='', train=True, benchmark=False):
    method _set_filesystem (line 12) | def _set_filesystem(self, dir_data):

FILE: src/data/sr291.py
  class SR291 (line 3) | class SR291(srdata.SRData):
    method __init__ (line 4) | def __init__(self, args, name='SR291', train=True, benchmark=False):

FILE: src/data/srdata.py
  class SRData (line 15) | class SRData(data.Dataset):
    method __init__ (line 16) | def __init__(self, args, name='', train=True, benchmark=False):
    method _scan (line 71) | def _scan(self):
    method _set_filesystem (line 87) | def _set_filesystem(self, dir_data):
    method _check_and_load (line 94) | def _check_and_load(self, ext, img, f, verbose=True):
    method __getitem__ (line 101) | def __getitem__(self, idx):
    method __len__ (line 109) | def __len__(self):
    method _get_index (line 115) | def _get_index(self, idx):
    method _load_file (line 121) | def _load_file(self, idx):
    method get_patch (line 140) | def get_patch(self, lr, hr):
    method set_scale (line 158) | def set_scale(self, idx_scale):

FILE: src/data/video.py
  class Video (line 12) | class Video(data.Dataset):
    method __init__ (line 13) | def __init__(self, args, name='Video', train=False, benchmark=False):
    method __getitem__ (line 27) | def __getitem__(self, idx):
    method __len__ (line 39) | def __len__(self):
    method set_scale (line 42) | def set_scale(self, idx_scale):

FILE: src/dataloader.py
  function _ms_loop (line 22) | def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, s...
  class _MSDataLoaderIter (line 68) | class _MSDataLoaderIter(_DataLoaderIter):
    method __init__ (line 70) | def __init__(self, loader):
  class MSDataLoader (line 148) | class MSDataLoader(DataLoader):
    method __init__ (line 150) | def __init__(self, cfg, *args, **kwargs):
    method __iter__ (line 156) | def __iter__(self):

FILE: src/loss/__init__.py
  class Loss (line 14) | class Loss(nn.modules.loss._Loss):
    method __init__ (line 15) | def __init__(self, args, ckp):
    method forward (line 69) | def forward(self, sr, hr):
    method step (line 86) | def step(self):
    method start_log (line 91) | def start_log(self):
    method end_log (line 94) | def end_log(self, n_batches):
    method display_loss (line 97) | def display_loss(self, batch):
    method plot_loss (line 105) | def plot_loss(self, apath, epoch):
    method get_loss_module (line 119) | def get_loss_module(self):
    method save (line 125) | def save(self, apath):
    method load (line 129) | def load(self, apath, cpu=False):

FILE: src/loss/adversarial.py
  class Adversarial (line 12) | class Adversarial(nn.Module):
    method __init__ (line 13) | def __init__(self, args, gan_type):
    method forward (line 35) | def forward(self, fake, real):
    method state_dict (line 95) | def state_dict(self, *args, **kwargs):
    method bce (line 101) | def bce(self, real, fake):

FILE: src/loss/discriminator.py
  class Discriminator (line 5) | class Discriminator(nn.Module):
    method __init__ (line 9) | def __init__(self, args):
    method forward (line 50) | def forward(self, x):

FILE: src/loss/vgg.py
  class VGG (line 8) | class VGG(nn.Module):
    method __init__ (line 9) | def __init__(self, conv_index, rgb_range=1):
    method forward (line 24) | def forward(self, sr, hr):

FILE: src/main.py
  function main (line 13) | def main():

FILE: src/model/__init__.py
  class Model (line 10) | class Model(nn.Module):
    method __init__ (line 11) | def __init__(self, args, ckp):
    method forward (line 39) | def forward(self, x, idx_scale):
    method save (line 60) | def save(self, apath, epoch, is_best=False):
    method load (line 73) | def load(self, apath, pre_train='', resume=-1, cpu=False):
    method forward_chop (line 106) | def forward_chop(self, x, shave=10, min_size=160000):
    method forward_x8 (line 147) | def forward_x8(self, *args, forward_function=None):

FILE: src/model/common.py
  function default_conv (line 7) | def default_conv(in_channels, out_channels, kernel_size, bias=True):
  class MeanShift (line 12) | class MeanShift(nn.Conv2d):
    method __init__ (line 13) | def __init__(
  class BasicBlock (line 24) | class BasicBlock(nn.Sequential):
    method __init__ (line 25) | def __init__(
  class ResBlock (line 37) | class ResBlock(nn.Module):
    method __init__ (line 38) | def __init__(
    method forward (line 54) | def forward(self, x):
  class Upsampler (line 60) | class Upsampler(nn.Sequential):
    method __init__ (line 61) | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):

FILE: src/model/dcn/deform_conv.py
  class DeformConvFunction (line 15) | class DeformConvFunction(Function):
    method forward (line 17) | def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=...
    method backward (line 51) | def backward(ctx, grad_output):
    method _output_size (line 82) | def _output_size(input, weight, padding, dilation, stride):
  class ModulatedDeformConvFunction (line 97) | class ModulatedDeformConvFunction(Function):
    method forward (line 99) | def forward(ctx, input, offset, mask, weight, bias=None, stride=1, pad...
    method backward (line 124) | def backward(ctx, grad_output):
    method _infer_shape (line 145) | def _infer_shape(ctx, input, weight):
  class DeformConv (line 161) | class DeformConv(nn.Module):
    method __init__ (line 162) | def __init__(self, in_channels, out_channels, kernel_size, stride=1, p...
    method reset_parameters (line 188) | def reset_parameters(self):
    method forward (line 195) | def forward(self, x, offset):
  class DeformConvPack (line 200) | class DeformConvPack(DeformConv):
    method __init__ (line 201) | def __init__(self, *args, **kwargs):
    method init_offset (line 211) | def init_offset(self):
    method forward (line 215) | def forward(self, x):
  class ModulatedDeformConv (line 221) | class ModulatedDeformConv(nn.Module):
    method __init__ (line 222) | def __init__(self, in_channels, out_channels, kernel_size, stride=1, p...
    method reset_parameters (line 243) | def reset_parameters(self):
    method forward (line 252) | def forward(self, x, offset, mask):
  class ModulatedDeformConvPack (line 258) | class ModulatedDeformConvPack(ModulatedDeformConv):
    method __init__ (line 259) | def __init__(self, *args, extra_offset_mask=False, **kwargs):
    method init_offset (line 270) | def init_offset(self):
    method forward (line 274) | def forward(self, x):

FILE: src/model/dcn/setup.py
  function make_cuda_ext (line 5) | def make_cuda_ext(name, sources):

FILE: src/model/dcn/src/deform_conv_cuda.cpp
  function shape_check (line 61) | void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOu...
  function deform_conv_forward_cuda (line 151) | int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
  function deform_conv_backward_input_cuda (line 260) | int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
  function deform_conv_backward_parameters_cuda (line 373) | int deform_conv_backward_parameters_cuda(
  function modulated_deform_conv_cuda_forward (line 486) | void modulated_deform_conv_cuda_forward(
  function modulated_deform_conv_cuda_backward (line 566) | void modulated_deform_conv_cuda_backward(
  function PYBIND11_MODULE (line 681) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: src/model/ddbpn.py
  function make_model (line 10) | def make_model(args, parent=False):
  function projection_conv (line 13) | def projection_conv(in_channels, out_channels, scale, up=True):
  class DenseProjection (line 29) | class DenseProjection(nn.Module):
    method __init__ (line 30) | def __init__(self, in_channels, nr, scale, up=True, bottleneck=True):
    method forward (line 55) | def forward(self, x):
  class DDBPN (line 68) | class DDBPN(nn.Module):
    method __init__ (line 69) | def __init__(self, args):
    method forward (line 112) | def forward(self, x):

FILE: src/model/edsr.py
  function make_model (line 14) | def make_model(args, parent=False):
  class EDSR (line 17) | class EDSR(nn.Module):
    method __init__ (line 18) | def __init__(self, args, conv=common.default_conv):
    method forward (line 55) | def forward(self, x):
    method load_state_dict (line 67) | def load_state_dict(self, state_dict, strict=True):

FILE: src/model/han.py
  function make_model (line 6) | def make_model(args, parent=False):
  class CALayer (line 10) | class CALayer(nn.Module):
    method __init__ (line 11) | def __init__(self, channel, reduction=16):
    method forward (line 23) | def forward(self, x):
  class LAM_Module (line 28) | class LAM_Module(nn.Module):
    method __init__ (line 30) | def __init__(self, in_dim):
    method forward (line 37) | def forward(self,x):
  class CSAM_Module (line 60) | class CSAM_Module(nn.Module):
    method __init__ (line 62) | def __init__(self, in_dim):
    method forward (line 71) | def forward(self,x):
  class RCAB (line 99) | class RCAB(nn.Module):
    method __init__ (line 100) | def __init__(
    method forward (line 114) | def forward(self, x):
  class ResidualGroup (line 121) | class ResidualGroup(nn.Module):
    method __init__ (line 122) | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scal...
    method forward (line 132) | def forward(self, x):
  class HAN (line 138) | class HAN(nn.Module):
    method __init__ (line 139) | def __init__(self, args, conv=common.default_conv):
    method forward (line 181) | def forward(self, x):
    method load_state_dict (line 212) | def load_state_dict(self, state_dict, strict=False):

FILE: src/model/matrixmodel.py
  function initialize_weights (line 30) | def initialize_weights(net_l, scale=1):
  class ResBlock (line 49) | class ResBlock(nn.Module):
    method __init__ (line 50) | def __init__(
    method forward (line 65) | def forward(self, x):
  class BFN (line 71) | class BFN(nn.Module):
    method __init__ (line 72) | def __init__(self, num_channels, kernel_size, reduction, n_blocks, blo...
    method _make_blocks (line 90) | def _make_blocks(self, in_channels, num_channels, kernel_size, reducti...
    method forward (line 98) | def forward(self, x):
  class BFN1 (line 117) | class BFN1(nn.Module):
    method __init__ (line 118) | def __init__(self, num_channels, kernel_size, reduction, n_blocks, blo...
    method _make_blocks (line 128) | def _make_blocks(self, in_channels, num_channels, kernel_size, reducti...
    method forward (line 136) | def forward(self, x):
  class BFN2 (line 144) | class BFN2(nn.Module):
    method __init__ (line 145) | def __init__(self, num_channels, kernel_size, reduction, n_blocks, blo...
    method _make_blocks (line 159) | def _make_blocks(self, in_channels, num_channels, kernel_size, reducti...
    method forward (line 167) | def forward(self, x):
  class EoctResBlock (line 181) | class EoctResBlock(nn.Module):
    method __init__ (line 184) | def __init__(self, in_channels, num_channels, stride=1, downsample=Non...
    method forward (line 193) | def forward(self, x):
  class EoctBottleneck (line 213) | class EoctBottleneck(nn.Module):
    method __init__ (line 214) | def __init__(self, in_channels, num_channels, stride=1, downsample=Non...
    method forward (line 227) | def forward(self, x):
  class CALayer (line 250) | class CALayer(nn.Module):
    method __init__ (line 251) | def __init__(self, in_channels, num_channels, reduction=16):
    method forward (line 259) | def forward(self, x):
  class CAEoctResBlock (line 270) | class CAEoctResBlock(nn.Module):
    method __init__ (line 271) | def __init__(self, in_channels, num_channels, reduction, bias=True, re...
    method forward (line 279) | def forward(self, x):
  function make_model (line 302) | def make_model(args, parent=False):
  class MatrixModel (line 305) | class MatrixModel(nn.Module):
    method __init__ (line 306) | def __init__(self, args):
    method _make_blocks (line 344) | def _make_blocks(self, in_channels, num_channels, kernel_size, reducti...
    method forward (line 352) | def forward(self, x):
    method load_state_dict (line 369) | def load_state_dict(self, state_dict, strict=False):
  class RERB (line 395) | class RERB(nn.Module):
    method __init__ (line 396) | def __init__(self, in_channels, num_channels, kernel_size, reduction, ...
    method _make_blocks (line 404) | def _make_blocks(self, in_channels, num_channels, kernel_size, reducti...
    method forward (line 412) | def forward(self, x):
  class MatrixModelB (line 421) | class MatrixModelB(nn.Module):
    method __init__ (line 422) | def __init__(self, args):
    method forward (line 496) | def forward(self, x):
    method load_state_dict (line 548) | def load_state_dict(self, state_dict, strict=False):
  class PDF (line 574) | class PDF(nn.Module):
    method __init__ (line 580) | def __init__(self, nf=64, groups=8):
    method forward (line 607) | def forward(self, nbr_fea_l):
  class PD (line 643) | class PD(nn.Module):
    method __init__ (line 649) | def __init__(self, nf=64, groups=8):
    method forward (line 681) | def forward(self, nbr_fea_l):
  class MatrixModelC (line 724) | class MatrixModelC(nn.Module):
    method __init__ (line 725) | def __init__(self, args):
    method forward (line 774) | def forward(self, x):
    method load_state_dict (line 813) | def load_state_dict(self, state_dict, strict=False):
  class MatrixModelD (line 840) | class MatrixModelD(nn.Module):
    method __init__ (line 841) | def __init__(self, args):
    method forward (line 908) | def forward(self, x):
    method load_state_dict (line 949) | def load_state_dict(self, state_dict, strict=False):
  class MatrixModelE (line 976) | class MatrixModelE(nn.Module):
    method __init__ (line 977) | def __init__(self, args):
    method forward (line 1047) | def forward(self, x):
    method load_state_dict (line 1085) | def load_state_dict(self, state_dict, strict=False):
  class MatrixModelF (line 1112) | class MatrixModelF(nn.Module):
    method __init__ (line 1113) | def __init__(self, args):
    method forward (line 1186) | def forward(self, x):
    method load_state_dict (line 1224) | def load_state_dict(self, state_dict, strict=False):
  class MatrixModelG (line 1251) | class MatrixModelG(nn.Module):
    method __init__ (line 1252) | def __init__(self, args):
    method forward (line 1325) | def forward(self, x):
    method load_state_dict (line 1369) | def load_state_dict(self, state_dict, strict=False):
  class MatrixModelG2 (line 1395) | class MatrixModelG2(nn.Module):
    method __init__ (line 1396) | def __init__(self, args):
    method forward (line 1474) | def forward(self, x):
    method load_state_dict (line 1520) | def load_state_dict(self, state_dict, strict=False):
  class MatrixModelF2 (line 1546) | class MatrixModelF2(nn.Module):
    method __init__ (line 1547) | def __init__(self, args):
    method forward (line 1640) | def forward(self, x):
    method load_state_dict (line 1694) | def load_state_dict(self, state_dict, strict=False):
  class PAM_Module (line 1720) | class PAM_Module(nn.Module):
    method __init__ (line 1723) | def __init__(self, in_dim):
    method forward (line 1733) | def forward(self, x):
  class CAM_Module (line 1755) | class CAM_Module(nn.Module):
    method __init__ (line 1757) | def __init__(self, in_dim):
    method forward (line 1764) | def forward(self,x):
  class GAM_Module (line 1786) | class GAM_Module(nn.Module):
    method __init__ (line 1788) | def __init__(self, in_dim):
    method forward (line 1795) | def forward(self,x):
  class DAM_Module (line 1818) | class DAM_Module(nn.Module):
    method __init__ (line 1820) | def __init__(self, in_dim):
    method forward (line 1827) | def forward(self,x):
  class MatrixModelH (line 1850) | class MatrixModelH(nn.Module):
    method __init__ (line 1851) | def __init__(self, args):
    method forward (line 1928) | def forward(self, x):
    method load_state_dict (line 1972) | def load_state_dict(self, state_dict, strict=False):

FILE: src/model/mdsr.py
  function make_model (line 10) | def make_model(args, parent=False):
  class MDSR (line 13) | class MDSR(nn.Module):
    method __init__ (line 14) | def __init__(self, args, conv=common.default_conv):
    method forward (line 51) | def forward(self, x):
    method set_scale (line 65) | def set_scale(self, scale_idx):

FILE: src/model/ops.py
  class EoctConv (line 10) | class EoctConv(nn.Module):
    method __init__ (line 11) | def __init__(self, in_channels, num_channels, kernel_size=3, stride=1,...
    method forward (line 55) | def forward(self, data):
  function relu (line 126) | def relu(data):
  function sigmoid (line 146) | def sigmoid(data):
  function bn (line 165) | def bn(data, num_channels):
  function max_pool2d (line 182) | def max_pool2d(data, l=(2,2)):
  function avg_pool2d (line 201) | def avg_pool2d(data):
  function dropout (line 221) | def dropout(data, l):
  function dataSum (line 241) | def dataSum(a, b):
  function tupleSum (line 252) | def tupleSum(a,b):
  class MeanShift (line 256) | class MeanShift(nn.Conv2d):
    method __init__ (line 257) | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
  class _UpsampleBlock (line 266) | class _UpsampleBlock(nn.Module):
    method __init__ (line 267) | def __init__(self,
    method forward (line 288) | def forward(self, x):
  function tupleMultiply (line 302) | def tupleMultiply(a, b):

FILE: src/model/rcan.py
  function make_model (line 7) | def make_model(args, parent=False):
  class CALayer (line 11) | class CALayer(nn.Module):
    method __init__ (line 12) | def __init__(self, channel, reduction=16):
    method forward (line 24) | def forward(self, x):
  class Ada_conv (line 31) | class Ada_conv(nn.Module):
    method __init__ (line 32) | def __init__(self, in_channels, out_channels, kernel_size, bias=True, ...
    method forward (line 45) | def forward(self, x):
  class ResAda_conv (line 56) | class ResAda_conv(nn.Module):
    method __init__ (line 57) | def __init__(self, in_channels, out_channels, kernel_size, bias=True, ...
    method forward (line 70) | def forward(self, x):
  class RCAB (line 81) | class RCAB(nn.Module):
    method __init__ (line 82) | def __init__(
    method forward (line 97) | def forward(self, x):
  class ResidualGroup (line 104) | class ResidualGroup(nn.Module):
    method __init__ (line 105) | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scal...
    method forward (line 116) | def forward(self, x):
  class RCAN (line 122) | class RCAN(nn.Module):
    method __init__ (line 123) | def __init__(self, args, conv=common.default_conv):
    method forward (line 161) | def forward(self, x):
    method load_state_dict (line 173) | def load_state_dict(self, state_dict, strict=False):

FILE: src/model/rcan1.py
  function make_model (line 8) | def make_model(args, parent=False):
  class CALayer (line 12) | class CALayer(nn.Module):
    method __init__ (line 13) | def __init__(self, channel, reduction=16):
    method forward (line 25) | def forward(self, x):
  class Dis (line 30) | class Dis(nn.Module):
    method __init__ (line 31) | def __init__(self, loss_type='L1', batchsize=16):
    method forward (line 39) | def forward(self, x1, x2):
    method L1Loss (line 51) | def L1Loss(self, x1, x2):
    method L2Loss (line 56) | def L2Loss(self, x1, x2):
    method bit_product_sum (line 61) | def bit_product_sum(self, x, y):
    method cosine_similarity (line 65) | def cosine_similarity(self, x, y, norm=True):
  class FullConvRes (line 87) | class FullConvRes(nn.Module):
    method __init__ (line 89) | def __init__(self, out_channels=64, in_channels=64, K=9):
    method forward (line 108) | def forward(self,x):
  class FullConvRes1 (line 189) | class FullConvRes1(nn.Module):
    method __init__ (line 191) | def __init__(self, out_channels=64, in_channels=64, kernel_size=3):
    method forward (line 210) | def forward(self,x):
  class FullConv (line 273) | class FullConv(nn.Module):
    method __init__ (line 275) | def __init__(self, out_channels=64, in_channels=64, kernel_size=3):
    method forward (line 288) | def forward(self,x):
  class RCAB (line 328) | class RCAB(nn.Module):
    method __init__ (line 329) | def __init__(
    method forward (line 343) | def forward(self, x):
  class ResidualGroup (line 350) | class ResidualGroup(nn.Module):
    method __init__ (line 351) | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scal...
    method forward (line 362) | def forward(self, x):
  class RCAN (line 368) | class RCAN(nn.Module):
    method __init__ (line 369) | def __init__(self, args, conv=common.default_conv):
    method forward (line 409) | def forward(self, x):
    method load_state_dict (line 421) | def load_state_dict(self, state_dict, strict=False):
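
`Dis` in `rcan1.py` exposes `L1Loss`, `L2Loss`, `bit_product_sum` and `cosine_similarity(x, y, norm=True)`, but the index shows only their signatures. The following is a hedged sketch of what those names conventionally compute on two flattened feature vectors. The function names are lowercase stand-ins, the real methods may operate on tensors rather than Python lists, and reading `norm=True` as rescaling the cosine from [-1, 1] to [0, 1] is an assumption.

```python
import math
from typing import Sequence


def l1_loss(x: Sequence[float], y: Sequence[float]) -> float:
    """Mean absolute difference."""
    return sum(abs(a - b) for a, b in zip(x, y)) / len(x)


def l2_loss(x: Sequence[float], y: Sequence[float]) -> float:
    """Mean squared difference."""
    return sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)


def bit_product_sum(x: Sequence[float], y: Sequence[float]) -> float:
    """Sum of element-wise products, i.e. the dot product."""
    return sum(a * b for a, b in zip(x, y))


def cosine_similarity(x: Sequence[float], y: Sequence[float], norm: bool = True) -> float:
    """Cosine of the angle between x and y; norm=True maps it to [0, 1] (assumed convention)."""
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    denom = math.sqrt(bit_product_sum(x, x)) * math.sqrt(bit_product_sum(y, y))
    if denom == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    cos = bit_product_sum(x, y) / denom
    return 0.5 * cos + 0.5 if norm else cos
```

L1 and L2 measure magnitude differences, while the cosine ignores scale entirely: `[1, 2]` and `[2, 4]` have cosine 1 even though their L1 distance is nonzero.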

FILE: src/model/rcan3.py
  function make_model (line 9) | def make_model(args, parent=False):
  class CALayer (line 13) | class CALayer(nn.Module):
    method __init__ (line 14) | def __init__(self, channel, reduction=16):
    method forward (line 26) | def forward(self, x):
  class MSCALayer (line 31) | class MSCALayer(nn.Module):
    method __init__ (line 32) | def __init__(self):
  class Dis (line 35) | class Dis(nn.Module):
    method __init__ (line 36) | def __init__(self, loss_type='L1', B=4):
    method forward (line 44) | def forward(self, x1, x2):
    method L1Loss (line 56) | def L1Loss(self, x1, x2):
    method L2Loss (line 61) | def L2Loss(self, x1, x2):
    method bit_product_sum (line 66) | def bit_product_sum(self, x, y):
    method cosine_similarity (line 70) | def cosine_similarity(self, x, y, norm=True):
  class DAM_Module (line 91) | class DAM_Module(nn.Module):
    method __init__ (line 93) | def __init__(self, in_dim):
    method forward (line 100) | def forward(self,x):
  class SEDAM_Module (line 123) | class SEDAM_Module(nn.Module):
    method __init__ (line 125) | def __init__(self, in_dim):
    method forward (line 136) | def forward(self,x):
  class MSAM_Module (line 172) | class MSAM_Module(nn.Module):
    method __init__ (line 174) | def __init__(self, in_dim):
    method forward (line 195) | def forward(self,x):
    method attention (line 223) | def attention(self, x):
    method one_scale (line 228) | def one_scale(self, x, scale=2):
    method multi_scale (line 240) | def multi_scale(self, x):
  class SAM_Module (line 252) | class SAM_Module(nn.Module):
    method __init__ (line 254) | def __init__(self, in_dim):
    method forward (line 271) | def forward(self,x):
    method depixel_shuffle (line 313) | def depixel_shuffle(self, x, upscale_factor=2):
    method squaremax (line 334) | def squaremax(self, x, dim=-1):
    method logmax (line 340) | def logmax(self,x):
    method absmax (line 346) | def absmax(self,x):
  class SECAM_Module (line 352) | class SECAM_Module(nn.Module):
    method __init__ (line 353) | def __init__(self, in_dim):
    method forward (line 365) | def forward(self,x):
  class LAM_Module (line 402) | class LAM_Module(nn.Module):
    method __init__ (line 404) | def __init__(self, in_dim):
    method forward (line 413) | def forward(self,x):
  class GAM_Module (line 458) | class GAM_Module(nn.Module):
    method __init__ (line 461) | def __init__(self, in_dim):
    method forward (line 470) | def forward(self,x):
  class RCAB (line 498) | class RCAB(nn.Module):
    method __init__ (line 499) | def __init__(
    method forward (line 513) | def forward(self, x):
  class ResidualGroup (line 520) | class ResidualGroup(nn.Module):
    method __init__ (line 521) | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scal...
    method forward (line 532) | def forward(self, x):
  class RCAN (line 538) | class RCAN(nn.Module):
    method __init__ (line 539) | def __init__(self, args, conv=common.default_conv):
    method forward (line 582) | def forward(self, x):
    method load_state_dict (line 614) | def load_state_dict(self, state_dict, strict=False):
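
`LAM_Module` corresponds to the layer attention module described in the HAN paper. The outputs of N intermediate groups are flattened, their pairwise correlations are turned into an N x N attention matrix with a row-wise softmax, and the re-weighted features are added back through a learnable scale. A dependency-free sketch of that computation follows. The function name, the list-based layout, and passing `gamma` as a plain float are assumptions; in a real module `gamma` is a learnable parameter, commonly initialised to 0 so the module starts as an identity. The body of `rcan3.py` may differ (for example, some DANet-derived implementations transform the energy before the softmax).

```python
import math
from typing import List


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def layer_attention(features: List[List[float]], gamma: float) -> List[List[float]]:
    """features[i] is the flattened (C*H*W) output of the i-th group; returns N re-weighted maps."""
    n = len(features)
    # N x N correlation ("energy") between every pair of group outputs.
    energy = [[sum(a * b for a, b in zip(features[i], features[j])) for j in range(n)]
              for i in range(n)]
    attn = [softmax(row) for row in energy]
    out = []
    for i in range(n):
        # Weighted sum of all group outputs, using row i of the attention matrix.
        mixed = [sum(attn[i][j] * features[j][k] for j in range(n))
                 for k in range(len(features[i]))]
        # Scaled residual connection: gamma = 0 leaves the input unchanged.
        out.append([gamma * m + f for m, f in zip(mixed, features[i])])
    return out


groups = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
unchanged = layer_attention(groups, gamma=0.0)  # equals groups
```

Because attention is computed across layers rather than across channels or pixels, the cost grows with N squared but N (the number of residual groups) is small.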

FILE: src/model/rcan4.py
  function make_model (line 6) | def make_model(args, parent=False):
  class CALayer (line 10) | class CALayer(nn.Module):
    method __init__ (line 11) | def __init__(self, channel, reduction=16):
    method forward (line 23) | def forward(self, x):
  class DAM_Module (line 28) | class DAM_Module(nn.Module):
    method __init__ (line 30) | def __init__(self, in_dim):
    method forward (line 37) | def forward(self,x):
  class GAM_Module (line 60) | class GAM_Module(nn.Module):
    method __init__ (line 63) | def __init__(self, in_dim):
    method forward (line 72) | def forward(self,x):
  class RCAB (line 100) | class RCAB(nn.Module):
    method __init__ (line 101) | def __init__(
    method forward (line 115) | def forward(self, x):
  class ResidualGroup (line 122) | class ResidualGroup(nn.Module):
    method __init__ (line 123) | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scal...
    method forward (line 133) | def forward(self, x):
  class RCAN (line 139) | class RCAN(nn.Module):
    method __init__ (line 140) | def __init__(self, args, conv=common.default_conv):
    method forward (line 183) | def forward(self, x):
    method load_state_dict (line 220) | def load_state_dict(self, state_dict, strict=False):
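
Every RCAN variant here overrides `load_state_dict(self, state_dict, strict=False)`. In the original RCAN code this override tolerates shape mismatches in `tail` (upsampler) parameters, so that a checkpoint trained at one scale can initialise a model for another. Assuming these variants keep that behaviour, the logic can be sketched with plain dictionaries, where list length stands in for tensor shape and all parameter names are hypothetical.

```python
from typing import Dict, List

Params = Dict[str, List[float]]  # parameter name -> flattened values; len() stands in for shape


def load_state_dict(own: Params, incoming: Params, strict: bool = False) -> List[str]:
    """Copy compatible entries of `incoming` into `own` in place; return skipped tail entries."""
    skipped = []
    for name, value in incoming.items():
        if name in own:
            if len(own[name]) == len(value):
                own[name][:] = value
            elif "tail" in name:
                # Upsampler built for a different scale: keep the freshly initialised weights.
                skipped.append(name)
            else:
                raise RuntimeError(
                    f"size mismatch for {name}: expected {len(own[name])}, got {len(value)}")
        elif strict and "tail" not in name:
            raise KeyError(f'unexpected key "{name}" in state_dict')
    if strict:
        missing = set(own) - set(incoming)
        if missing:
            raise KeyError(f"missing keys in state_dict: {sorted(missing)}")
    return skipped


# Usage: a checkpoint whose upsampler has a different size (e.g. trained for another scale).
model = {"head.weight": [0.0, 0.0], "tail.weight": [0.0] * 4}
ckpt = {"head.weight": [1.0, 2.0], "tail.weight": [9.0, 9.0], "extra.bias": [1.0]}
skipped = load_state_dict(model, ckpt)  # strict=False, so "extra.bias" is ignored
```

Note that the `tail` exemption applies regardless of `strict`; `strict` only controls whether unexpected or missing keys are errors.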

FILE: src/model/rdn.py
  function make_model (line 10) | def make_model(args, parent=False):
  class RDB_Conv (line 13) | class RDB_Conv(nn.Module):
    method __init__ (line 14) | def __init__(self, inChannels, growRate, kSize=3):
    method forward (line 23) | def forward(self, x):
  class RDB (line 27) | class RDB(nn.Module):
    method __init__ (line 28) | def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
    method forward (line 42) | def forward(self, x):
  class RDN (line 45) | class RDN(nn.Module):
    method __init__ (line 46) | def __init__(self, args):
    method forward (line 93) | def forward(self, x):
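
`RDB` takes `growRate0` (G0), `growRate` (G) and `nConvLayers` (C). Assuming the standard design from the RDN paper cited in `rdn.py`, each `RDB_Conv` concatenates its input with G new channels, and a 1x1 local-feature-fusion conv maps the result back to G0. The channel bookkeeping, and the parameter count for kxk convolutions with bias, can be worked out as follows. The helper names are hypothetical.

```python
from typing import List, Tuple


def rdb_channel_plan(g0: int, g: int, c: int) -> Tuple[List[Tuple[int, int]], Tuple[int, int]]:
    """(in, out) channels for each RDB_Conv and for the 1x1 local feature fusion conv."""
    convs = [(g0 + i * g, g) for i in range(c)]  # conv i sees the block input plus i earlier outputs
    lff = (g0 + c * g, g0)                       # fuse everything back to G0 channels
    return convs, lff


def rdb_param_count(g0: int, g: int, c: int, k: int = 3) -> int:
    """Weights plus biases of one RDB: C kxk convs and one 1x1 fusion conv."""
    convs, (lff_in, lff_out) = rdb_channel_plan(g0, g, c)
    total = sum(cin * cout * k * k + cout for cin, cout in convs)
    return total + lff_in * lff_out + lff_out


convs, lff = rdb_channel_plan(64, 64, 8)
```

For G0 = 64, G = 64, C = 8, the last conv sees 512 input channels and the fusion conv maps 576 channels back to 64, which is why the block output can be added to its input through the local residual connection.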

FILE: src/model/rdn1.py
  function make_model (line 10) | def make_model(args, parent=False):
  class RDB_Conv (line 13) | class RDB_Conv(nn.Module):
    method __init__ (line 14) | def __init__(self, inChannels, growRate, kSize=(3,3,3)):
    method forward (line 24) | def forward(self, x):
  class DAM_Module (line 31) | class DAM_Module(nn.Module):
    method __init__ (line 33) | def __init__(self):
    method forward (line 39) | def forward(self,x):
  class RDB (line 62) | class RDB(nn.Module):
    method __init__ (line 63) | def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
    method forward (line 79) | def forward(self, x):
  class RDN (line 84) | class RDN(nn.Module):
    method __init__ (line 85) | def __init__(self, args):
    method forward (line 133) | def forward(self, x):

FILE: src/model/rdn2.py
  function make_model (line 10) | def make_model(args, parent=False):
  class RDB_Conv (line 13) | class RDB_Conv(nn.Module):
    method __init__ (line 14) | def __init__(self, inChannels, growRate, kSize=3):
    method forward (line 24) | def forward(self, x):
  class DAM_Module (line 33) | class DAM_Module(nn.Module):
    method __init__ (line 35) | def __init__(self):
    method forward (line 41) | def forward(self,x):
  class RDB (line 64) | class RDB(nn.Module):
    method __init__ (line 65) | def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
    method forward (line 81) | def forward(self, x):
  class RDN (line 88) | class RDN(nn.Module):
    method __init__ (line 89) | def __init__(self, args):
    method forward (line 137) | def forward(self, x):

FILE: src/model/vdsr.py
  function make_model (line 10) | def make_model(args, parent=False):
  class VDSR (line 13) | class VDSR(nn.Module):
    method __init__ (line 14) | def __init__(self, args, conv=common.default_conv):
    method forward (line 39) | def forward(self, x):

FILE: src/template.py
  function set_template (line 1) | def set_template(args):

FILE: src/trainer.py
  class Trainer (line 12) | class Trainer():
    method __init__ (line 13) | def __init__(self, args, loader, my_model, my_loss, ckp):
    method train (line 30) | def train(self):
    method test (line 78) | def test(self):
    method prepare (line 136) | def prepare(self, *args):
    method terminate (line 144) | def terminate(self):

FILE: src/utility.py
  class timer (line 19) | class timer():
    method __init__ (line 20) | def __init__(self):
    method tic (line 24) | def tic(self):
    method toc (line 27) | def toc(self, restart=False):
    method hold (line 32) | def hold(self):
    method release (line 35) | def release(self):
    method reset (line 41) | def reset(self):
  class checkpoint (line 44) | class checkpoint():
    method __init__ (line 45) | def __init__(self, args):
    method get_path (line 82) | def get_path(self, *subdir):
    method save (line 85) | def save(self, trainer, epoch, is_best=False):
    method add_log (line 94) | def add_log(self, log):
    method write_log (line 97) | def write_log(self, log, refresh=False):
    method done (line 104) | def done(self):
    method plot_psnr (line 107) | def plot_psnr(self, epoch):
    method begin_background (line 126) | def begin_background(self):
    method end_background (line 143) | def end_background(self):
    method save_results (line 148) | def save_results(self, dataset, filename, save_list, scale):
  function quantize (line 161) | def quantize(img, rgb_range):
  function calc_psnr (line 165) | def calc_psnr(sr, hr, scale, rgb_range, dataset=None):
  function make_optimizer (line 183) | def make_optimizer(args, target):
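
`quantize(img, rgb_range)` and `calc_psnr(sr, hr, scale, rgb_range, dataset=None)` appear to follow the EDSR-PyTorch utilities this codebase derives from. Below is a minimal single-channel sketch under these assumptions: the EDSR convention of shaving `scale` border pixels on benchmark sets (`scale + 6` otherwise), no RGB-to-Y conversion (which EDSR applies to color benchmark images), and a hypothetical `benchmark` flag in place of the `dataset` argument.

```python
import math
from typing import List

Image = List[List[float]]  # single-channel image indexed [row][col]


def quantize(img: Image, rgb_range: float) -> Image:
    """Snap values to the 0..255 integer grid, expressed back in units of rgb_range."""
    pixel_range = 255 / rgb_range
    return [[round(min(max(v * pixel_range, 0.0), 255.0)) / pixel_range for v in row]
            for row in img]


def calc_psnr(sr: Image, hr: Image, scale: int, rgb_range: float, benchmark: bool = True) -> float:
    """PSNR in dB on the interior of the image, after shaving a border of `shave` pixels."""
    shave = scale if benchmark else scale + 6
    crop = slice(shave, -shave) if shave > 0 else slice(None)
    diffs = [(s - h) / rgb_range
             for s_row, h_row in zip(sr[crop], hr[crop])
             for s, h in zip(s_row[crop], h_row[crop])]
    if not diffs:
        raise ValueError("image is smaller than the shaved border")
    mse = sum(d * d for d in diffs) / len(diffs)
    return math.inf if mse == 0 else -10 * math.log10(mse)
```

On a 6 x 6 image with `scale=2`, only the central 2 x 2 region contributes, so errors confined to the border do not lower the score.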

FILE: src/videotester.py
  class VideoTester (line 12) | class VideoTester():
    method __init__ (line 13) | def __init__(self, args, my_model, ckp):
    method test (line 22) | def test(self):
    method prepare (line 65) | def prepare(self, *args):

Condensed preview: 47 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 1799,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 4457,
    "preview": "## HAN\n\n> PyTorch code for our ECCV 2020 paper \"Single Image Super-Resolution via a Holistic Attention Network\"\n>\n> This"
  },
  {
    "path": "experiment/.gitignore",
    "chars": 27,
    "preview": "*\n!.gitignore\n!/model/*.pt\n"
  },
  {
    "path": "src/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/data/__init__.py",
    "chars": 1978,
    "preview": "from importlib import import_module\n#from dataloader import MSDataLoader\nfrom torch.utils.data import dataloader\nfrom to"
  },
  {
    "path": "src/data/benchmark.py",
    "chars": 1596,
    "preview": "import os\n\nfrom data import common\nfrom data import srdata\n\nimport numpy as np\n\nimport torch\nimport torch.utils.data as "
  },
  {
    "path": "src/data/common.py",
    "chars": 1786,
    "preview": "import random\n\nimport numpy as np\nimport skimage.color as sc\n\nimport torch\n\ndef get_patch(*args, patch_size=96, scale=2,"
  },
  {
    "path": "src/data/demo.py",
    "chars": 1075,
    "preview": "import os\n\nfrom data import common\n\nimport numpy as np\nimport imageio\n\nimport torch\nimport torch.utils.data as data\n\ncla"
  },
  {
    "path": "src/data/div2k.py",
    "chars": 1216,
    "preview": "import os\nfrom data import srdata\n\nclass DIV2K(srdata.SRData):\n    def __init__(self, args, name='DIV2K', train=True, be"
  },
  {
    "path": "src/data/div2kjpeg.py",
    "chars": 675,
    "preview": "import os\nfrom data import srdata\nfrom data import div2k\n\nclass DIV2KJPEG(div2k.DIV2K):\n    def __init__(self, args, nam"
  },
  {
    "path": "src/data/sr291.py",
    "chars": 180,
    "preview": "from data import srdata\n\nclass SR291(srdata.SRData):\n    def __init__(self, args, name='SR291', train=True, benchmark=Fa"
  },
  {
    "path": "src/data/srdata.py",
    "chars": 5506,
    "preview": "import os\nimport glob\nimport random\nimport pickle\n\nfrom data import common\n\nimport numpy as np\nimport imageio\nimport tor"
  },
  {
    "path": "src/data/video.py",
    "chars": 1207,
    "preview": "import os\n\nfrom data import common\n\nimport cv2\nimport numpy as np\nimport imageio\n\nimport torch\nimport torch.utils.data a"
  },
  {
    "path": "src/dataloader.py",
    "chars": 5259,
    "preview": "import threading\nimport random\n\nimport torch\nimport torch.multiprocessing as multiprocessing\nfrom torch.utils.data impor"
  },
  {
    "path": "src/demo.sh",
    "chars": 4906,
    "preview": "# EDSR baseline model (x2) + JPEG augmentation\n#python3 main.py --model MatrixModel --scale 4 --patch_size 192 --save Ma"
  },
  {
    "path": "src/loss/__init__.py",
    "chars": 4802,
    "preview": "import os\r\nfrom importlib import import_module\r\n\r\nimport matplotlib\r\nmatplotlib.use('Agg')\r\nimport matplotlib.pyplot as "
  },
  {
    "path": "src/loss/adversarial.py",
    "chars": 4393,
    "preview": "import utility\nfrom types import SimpleNamespace\n\nfrom model import common\nfrom loss import discriminator\n\nimport torch\n"
  },
  {
    "path": "src/loss/discriminator.py",
    "chars": 1595,
    "preview": "from model import common\n\nimport torch.nn as nn\n\nclass Discriminator(nn.Module):\n    '''\n        output is not normalize"
  },
  {
    "path": "src/loss/vgg.py",
    "chars": 1106,
    "preview": "from model import common\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.models a"
  },
  {
    "path": "src/main.py",
    "chars": 835,
    "preview": "import torch\n\nimport utility\nimport data\nimport model\nimport loss\nfrom option import args\nfrom trainer import Trainer\n\nt"
  },
  {
    "path": "src/model/__init__.py",
    "chars": 6492,
    "preview": "import os\nfrom importlib import import_module\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel as P\nimport t"
  },
  {
    "path": "src/model/common.py",
    "chars": 2782,
    "preview": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ndef default_conv(in_channels, out_chann"
  },
  {
    "path": "src/model/dcn/__init__.py",
    "chars": 306,
    "preview": "from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack,\n                    "
  },
  {
    "path": "src/model/dcn/deform_conv.py",
    "chars": 12234,
    "preview": "import math\nimport logging\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.autograd.f"
  },
  {
    "path": "src/model/dcn/setup.py",
    "chars": 711,
    "preview": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n\ndef make_cuda_ext(nam"
  },
  {
    "path": "src/model/dcn/src/deform_conv_cuda.cpp",
    "chars": 29235,
    "preview": "// modify from\n// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/def"
  },
  {
    "path": "src/model/dcn/src/deform_conv_cuda_kernel.cu",
    "chars": 42269,
    "preview": "/*!\n ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************\n *\n * COPYRIGHT\n *\n * All contribu"
  },
  {
    "path": "src/model/ddbpn.py",
    "chars": 3629,
    "preview": "# Deep Back-Projection Networks For Super-Resolution\n# https://arxiv.org/abs/1803.02735\n\nfrom model import common\n\nimpor"
  },
  {
    "path": "src/model/edsr.py",
    "chars": 3031,
    "preview": "from model import common\n\nimport torch.nn as nn\n\nurl = {\n    'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr"
  },
  {
    "path": "src/model/han.py",
    "chars": 8288,
    "preview": "from model import common\nimport torch\nimport torch.nn as nn\nimport pdb\n\ndef make_model(args, parent=False):\n    return H"
  },
  {
    "path": "src/model/matrixmodel.py",
    "chars": 75664,
    "preview": "# ------------------------------------------------------------------------------\n# Copyright (c) Microsoft\n# Licensed un"
  },
  {
    "path": "src/model/mdsr.py",
    "chars": 1926,
    "preview": "from model import common\n\nimport torch.nn as nn\n\nurl = {\n    'r16f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr_b"
  },
  {
    "path": "src/model/ops.py",
    "chars": 12835,
    "preview": "'''EoctConv'''\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport torch\r\nimport numpy as np\r\nimport math\r\ni"
  },
  {
    "path": "src/model/rcan.py",
    "chars": 7347,
    "preview": "from model import common\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport pdb\n\ndef make_model(args, parent=F"
  },
  {
    "path": "src/model/rcan1.py",
    "chars": 17848,
    "preview": "from model import common\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.init as init\nimport pdb\n\ndef make_model(arg"
  },
  {
    "path": "src/model/rcan3.py",
    "chars": 22423,
    "preview": "from model import common\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Variable\nimport pdb\nimport numpy "
  },
  {
    "path": "src/model/rcan4.py",
    "chars": 8623,
    "preview": "from model import common\nimport torch\nimport torch.nn as nn\nimport pdb\n\ndef make_model(args, parent=False):\n    return R"
  },
  {
    "path": "src/model/rdn.py",
    "chars": 3201,
    "preview": "# Residual Dense Network for Image Super-Resolution\n# https://arxiv.org/abs/1802.08797\n\nfrom model import common\n\nimport"
  },
  {
    "path": "src/model/rdn1.py",
    "chars": 4564,
    "preview": "# Residual Dense Network for Image Super-Resolution\n# https://arxiv.org/abs/1802.08797\n\nfrom model import common\n\nimport"
  },
  {
    "path": "src/model/rdn2.py",
    "chars": 4712,
    "preview": "# Residual Dense Network for Image Super-Resolution\n# https://arxiv.org/abs/1802.08797\n\nfrom model import common\n\nimport"
  },
  {
    "path": "src/model/vdsr.py",
    "chars": 1275,
    "preview": "from model import common\n\nimport torch.nn as nn\nimport torch.nn.init as init\n\nurl = {\n    'r20f64': ''\n}\n\ndef make_model"
  },
  {
    "path": "src/option.py",
    "chars": 7695,
    "preview": "import argparse\nimport template\n\nparser = argparse.ArgumentParser(description='EDSR and MDSR')\n\nparser.add_argument('--d"
  },
  {
    "path": "src/template.py",
    "chars": 2037,
    "preview": "def set_template(args):\n    # Set the templates here\n    if args.template.find('jpeg') >= 0:\n        args.data_train = '"
  },
  {
    "path": "src/trainer.py",
    "chars": 4968,
    "preview": "import os\nimport math\nfrom decimal import Decimal\n\nimport utility\n\nimport torch\nimport torch.nn.utils as utils\nfrom tqdm"
  },
  {
    "path": "src/utility.py",
    "chars": 7482,
    "preview": "import os\nimport math\nimport time\nimport datetime\nfrom multiprocessing import Process\nfrom multiprocessing import Queue\n"
  },
  {
    "path": "src/videotester.py",
    "chars": 2280,
    "preview": "import os\nimport math\n\nimport utility\nfrom data import common\n\nimport torch\nimport cv2\n\nfrom tqdm import tqdm\n\nclass Vid"
  }
]
