[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "## HAN\n\n> PyTorch code for our ECCV 2020 paper \"Single Image Super-Resolution via a Holistic Attention Network\"\n>\n> This repository is for HAN introduced in the following paper\n>\n> Ben Niu, Weilei Wen, Wenqi Ren, Xiangde Zhang, Lianping Yang, Shuzhen Wang, Kaihao Zhang, Xiaochun Cao, Haifeng Shen, \"Single Image Super-Resolution via a Holistic Attention Network\", ECCV 2020, [arxiv](https://arxiv.org/abs/2008.08767)\n>\n> The code is built on RCAN (PyTorch) and tested on Ubuntu 16.04/18.04 environment (Python3.6, PyTorch_0.4.0, CUDA8.0, cuDNN5.1) with Titan X/1080Ti/Xp GPUs.\n>\n> ### Contents\n>\n> ________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________________\n>\n> > 1. [Introduction](https://github.com/wwlCape/HAN#introduction)\n> > 2. [Train](https://github.com/wwlCape/HAN#begin-to-train)\n> > 3. [Test](https://github.com/wwlCape/HAN#begin-to-test)\n> > 4. [Acknowledgements](https://github.com/wwlCape/HAN#Acknowledgements)\n>\n> ### Introduction\n>\n> Informative features play a crucial role in the single image super-resolution task. Channel attention has been demonstrated to be effective for preserving information-rich features in each layer. However, channel attention treats each convolution layer as a separate process that misses the correlation among different layers. To address this problem, we propose a new holistic attention network (HAN), which consists of a layer attention module (LAM) and a channel-spatial attention module (CSAM), to model the holistic interdependencies among layers, channels, and positions. Specifically, the proposed LAM adaptively emphasizes hierarchical features by considering correlations among layers. Meanwhile, CSAM learns the confidence at all the positions of each channel to selectively capture more informative features. Extensive experiments demonstrate that the proposed HAN performs favorably against the state-of-the-art single image super- resolution approaches.\n>\n>\n> Train\n> Prepare training data\n> Download DIV2K training data (800 training + 100 validtion images) from DIV2K dataset.\n>\n> ### Begin to train\n>\n> (optional) Download models for our paper and place them in '/HAN/experiment/HAN'. All the models (BIX2/3/4/8, BDX3) can be downloaded from [GoogleDrive](https://drive.google.com/drive/folders/17cLcPCDLuBV5_5-ngd0vXIDp6rebIMG1). 
> ### Train\n>\n> #### Prepare training data\n>\n> Download DIV2K training data (800 training + 100 validation images) from the DIV2K dataset.\n>\n> ### Begin to train\n>\n> (optional) Download models for our paper and place them in '/HAN/experiment/HAN'. All the models (BIX2/3/4/8, BDX3) can be downloaded from [GoogleDrive](https://drive.google.com/drive/folders/17cLcPCDLuBV5_5-ngd0vXIDp6rebIMG1). You can use the scripts in 'demo.sh' to train the models for our paper.\n>\n> ```bash\n> # BI degradation, scale 2, 3, 4, 8\n> # HAN BI model (x2)\n> python main.py --template HAN --save HANx2 --scale 2 --reset --save_results --patch_size 96 --pre_train ../experiment/model/RCAN_BIX2.pt\n>\n> # HAN BI model (x3)\n> python main.py --template HAN --save HANx3 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt\n>\n> # HAN BI model (x4)\n> python main.py --template HAN --save HANx4 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt\n>\n> # HAN BI model (x8)\n> python main.py --template HAN --save HANx8 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX2.pt\n> ```\n>\n> ### Begin to Test\n>\n> Quick start: download models for our paper and place them in '/experiment/HAN', then cd to '/HAN/src' and run the following script.\n>\n> ```bash\n> # Test HAN (x2)\n> python main.py --template HAN --data_test Set5+Set14+B100+Urban100+Manga109 --data_range 801-900 --scale 2 --pre_train ../experiment/HAN/HAN_BIX2.pt --test_only --save HANx2_test --save_results\n> ```\n>\n> All the models (BIX2/3/4/8, BDX3) can be downloaded from [GoogleDrive](https://drive.google.com/drive/folders/17cLcPCDLuBV5_5-ngd0vXIDp6rebIMG1).\n>\n> The whole test pipeline:\n>\n> 1. Prepare test data: place the original test sets in '/dataset/x4/test' and run 'Prepare_TestData_HR_LR.m' in Matlab to generate HR/LR images with different degradation models.\n> 2. Conduct image SR: see the quick start above.\n> 3. Evaluate the results: run 'Evaluate_PSNR_SSIM.m' to obtain the PSNR/SSIM values reported in the paper.\n>\n> ### Acknowledgements\n>\n> This code is built on [RCAN](https://github.com/yulunzhang/RCAN). We thank the authors for sharing their RCAN code ([PyTorch version](https://github.com/yulunzhang/RCAN)).\n"
  },
  {
    "path": "experiment/.gitignore",
    "content": "*\n!.gitignore\n!/model/*.pt\n"
  },
  {
    "path": "src/__init__.py",
    "content": ""
  },
  {
    "path": "src/data/__init__.py",
    "content": "from importlib import import_module\n#from dataloader import MSDataLoader\nfrom torch.utils.data import dataloader\nfrom torch.utils.data import ConcatDataset\n\n# This is a simple wrapper function for ConcatDataset\nclass MyConcatDataset(ConcatDataset):\n    def __init__(self, datasets):\n        super(MyConcatDataset, self).__init__(datasets)\n        self.train = datasets[0].train\n\n    def set_scale(self, idx_scale):\n        for d in self.datasets:\n            if hasattr(d, 'set_scale'): d.set_scale(idx_scale)\n\nclass Data:\n    def __init__(self, args):\n        self.loader_train = None\n        if not args.test_only:\n            datasets = []\n            for d in args.data_train:\n                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'\n                m = import_module('data.' + module_name.lower())\n                datasets.append(getattr(m, module_name)(args, name=d))\n\n            self.loader_train = dataloader.DataLoader(\n                MyConcatDataset(datasets),\n                batch_size=args.batch_size,\n                shuffle=True,\n                pin_memory=not args.cpu,\n                num_workers=args.n_threads,\n            )\n\n        self.loader_test = []\n        for d in args.data_test:\n            if d in ['Val20', 'Set20', 'Set5', 'Set14', 'B100', 'Urban100','Manga109']:\n                m = import_module('data.benchmark')\n                testset = getattr(m, 'Benchmark')(args, train=False, name=d)\n            else:\n                module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG'\n                m = import_module('data.' + module_name.lower())\n                testset = getattr(m, module_name)(args, train=False, name=d)\n\n            self.loader_test.append(\n                dataloader.DataLoader(\n                    testset,\n                    batch_size=1,\n                    shuffle=False,\n                    pin_memory=not args.cpu,\n                    num_workers=args.n_threads,\n                )\n            )\n"
  },
  {
    "path": "src/data/benchmark.py",
    "content": "import os\n\nfrom data import common\nfrom data import srdata\n\nimport numpy as np\n\nimport torch\nimport torch.utils.data as data\nimport glob\nimport pdb\n\nclass Benchmark(srdata.SRData):\n    def __init__(self, args, name='', train=True, benchmark=True):\n        super(Benchmark, self).__init__(\n            args, name=name, train=train, benchmark=True)\n\n    def _scan(self):\n        list_hr = []\n        list_lr = [[] for _ in self.scale]\n        for entry in os.scandir(self.dir_hr):\n            filename = os.path.splitext(entry.name)[0]\n            if \"HR\" in filename:\n                list_hr.append(os.path.join(self.dir_hr, filename + self.ext))\n        #pdb.set_trace()\n        for entry in os.scandir(self.dir_lr):\n            filename = os.path.splitext(entry.name)[0]\n            if \"LR\" in filename:\n                for si, s in enumerate(self.scale):\n                    list_lr[si].append(os.path.join(\n                        self.dir_lr, filename + self.ext))\n\n        list_hr.sort()\n        for l in list_lr:\n            l.sort()\n\n        return list_hr, list_lr\n\n    def _set_filesystem(self, dir_data):\n        self.apath = os.path.join(dir_data, self.name)\n        self.all_files = glob.glob(os.path.join(self.apath, 'HR', \"*.png\"))\n        #self.dir_lr = os.path.join(dir_data, self.name, 'Test/3')\n        #self.dir_hr = os.path.join(dir_data, self.name, 'Test/3')\n        self.dir_lr = os.path.join(dir_data, self.name, 'LR','X4')\n        self.dir_hr = os.path.join(dir_data, self.name, 'HR')\n        #self.dir_lr = os.path.join(self.apath, 'LR_bicubic')\n        self.ext = '.png'"
  },
  {
    "path": "src/data/common.py",
    "content": "import random\n\nimport numpy as np\nimport skimage.color as sc\n\nimport torch\n\ndef get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False):\n    ih, iw = args[0].shape[:2]\n\n    if not input_large:\n        p = scale if multi else 1\n        tp = p * patch_size\n        ip = tp // scale\n    else:\n        tp = patch_size\n        ip = patch_size\n\n    ix = random.randrange(0, iw - ip + 1)\n    iy = random.randrange(0, ih - ip + 1)\n\n    if not input_large:\n        tx, ty = scale * ix, scale * iy\n    else:\n        tx, ty = ix, iy\n\n    ret = [\n        args[0][iy:iy + ip, ix:ix + ip, :],\n        *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]]\n    ]\n\n    return ret\n\ndef set_channel(*args, n_channels=3):\n    def _set_channel(img):\n        if img.ndim == 2:\n            img = np.expand_dims(img, axis=2)\n\n        c = img.shape[2]\n        if n_channels == 1 and c == 3:\n            img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)\n        elif n_channels == 3 and c == 1:\n            img = np.concatenate([img] * n_channels, 2)\n\n        return img\n\n    return [_set_channel(a) for a in args]\n\ndef np2Tensor(*args, rgb_range=255):\n    def _np2Tensor(img):\n        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))\n        tensor = torch.from_numpy(np_transpose).float()\n        tensor.mul_(rgb_range / 255)\n\n        return tensor\n\n    return [_np2Tensor(a) for a in args]\n\ndef augment(*args, hflip=True, rot=True):\n    hflip = hflip and random.random() < 0.5\n    vflip = rot and random.random() < 0.5\n    rot90 = rot and random.random() < 0.5\n\n    def _augment(img):\n        if hflip: img = img[:, ::-1, :]\n        if vflip: img = img[::-1, :, :]\n        if rot90: img = img.transpose(1, 0, 2)\n        \n        return img\n\n    return [_augment(a) for a in args]\n\n"
  },
  {
    "path": "src/data/demo.py",
    "content": "import os\n\nfrom data import common\n\nimport numpy as np\nimport imageio\n\nimport torch\nimport torch.utils.data as data\n\nclass Demo(data.Dataset):\n    def __init__(self, args, name='Demo', train=False, benchmark=False):\n        self.args = args\n        self.name = name\n        self.scale = args.scale\n        self.idx_scale = 0\n        self.train = False\n        self.benchmark = benchmark\n\n        self.filelist = []\n        for f in os.listdir(args.dir_demo):\n            if f.find('.png') >= 0 or f.find('.jp') >= 0:\n                self.filelist.append(os.path.join(args.dir_demo, f))\n        self.filelist.sort()\n\n    def __getitem__(self, idx):\n        filename = os.path.splitext(os.path.basename(self.filelist[idx]))[0]\n        lr = imageio.imread(self.filelist[idx])\n        lr, = common.set_channel(lr, n_channels=self.args.n_colors)\n        lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)\n\n        return lr_t, -1, filename\n\n    def __len__(self):\n        return len(self.filelist)\n\n    def set_scale(self, idx_scale):\n        self.idx_scale = idx_scale\n\n"
  },
  {
    "path": "src/data/div2k.py",
    "content": "import os\nfrom data import srdata\n\nclass DIV2K(srdata.SRData):\n    def __init__(self, args, name='DIV2K', train=True, benchmark=False):\n        data_range = [r.split('-') for r in args.data_range.split('/')]\n        if train:\n            data_range = data_range[0]\n        else:\n            if args.test_only and len(data_range) == 1:\n                data_range = data_range[0]\n            else:\n                data_range = data_range[1]\n\n        self.begin, self.end = list(map(lambda x: int(x), data_range))\n        super(DIV2K, self).__init__(\n            args, name=name, train=train, benchmark=benchmark\n        )\n\n    def _scan(self):\n        names_hr, names_lr = super(DIV2K, self)._scan()\n        names_hr = names_hr[self.begin - 1:self.end]\n        names_lr = [n[self.begin - 1:self.end] for n in names_lr]\n\n        return names_hr, names_lr\n\n    def _set_filesystem(self, dir_data):\n        super(DIV2K, self)._set_filesystem(dir_data)\n        self.apath = dir_data\n        self.dir_hr = os.path.join(self.apath, 'TrainHR')\n        self.dir_lr = os.path.join(self.apath, 'TrainLR')\n        #self.dir_lr = os.path.join(self.apath, 'dataset/DIV2K_train_HR')\n        if self.input_large: self.dir_lr += 'L'\n\n"
  },
  {
    "path": "src/data/div2kjpeg.py",
    "content": "import os\nfrom data import srdata\nfrom data import div2k\n\nclass DIV2KJPEG(div2k.DIV2K):\n    def __init__(self, args, name='', train=True, benchmark=False):\n        self.q_factor = int(name.replace('DIV2K-Q', ''))\n        super(DIV2KJPEG, self).__init__(\n            args, name=name, train=train, benchmark=benchmark\n        )\n\n    def _set_filesystem(self, dir_data):\n        self.apath = os.path.join(dir_data, 'DIV2K')\n        self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR')\n        self.dir_lr = os.path.join(\n            self.apath, 'DIV2K_Q{}'.format(self.q_factor)\n        )\n        if self.input_large: self.dir_lr += 'L'\n        self.ext = ('.png', '.jpg')\n\n"
  },
  {
    "path": "src/data/sr291.py",
    "content": "from data import srdata\n\nclass SR291(srdata.SRData):\n    def __init__(self, args, name='SR291', train=True, benchmark=False):\n        super(SR291, self).__init__(args, name=name)\n\n"
  },
  {
    "path": "src/data/srdata.py",
    "content": "import os\nimport glob\nimport random\nimport pickle\n\nfrom data import common\n\nimport numpy as np\nimport imageio\nimport torch\nimport torch.utils.data as data\nimport pdb\n#import pdb\n\nclass SRData(data.Dataset):\n    def __init__(self, args, name='', train=True, benchmark=False):\n        self.args = args\n        self.name = name\n        self.train = train\n        self.split = 'train' if train else 'test'\n        self.do_eval = True\n        self.benchmark = benchmark\n        self.input_large = (args.model == 'VDSR')\n        self.scale = args.scale\n        self.idx_scale = 0\n        \n        self._set_filesystem(args.dir_data)\n        if args.ext.find('img') < 0:\n            path_bin = os.path.join(self.apath, 'bin')\n            os.makedirs(path_bin, exist_ok=True)\n\n        list_hr, list_lr = self._scan()\n        if args.ext.find('img') >= 0 or benchmark:\n            self.images_hr, self.images_lr = list_hr, list_lr\n        elif args.ext.find('sep') >= 0:\n            os.makedirs(\n                self.dir_hr.replace(self.apath, path_bin),\n                exist_ok=True\n            )\n            for s in self.scale:\n                os.makedirs(\n                    os.path.join(\n                        self.dir_lr.replace(self.apath, path_bin),\n                        'X{}'.format(s)\n                    ),\n                    exist_ok=True\n                )\n            \n            self.images_hr, self.images_lr = [], [[] for _ in self.scale]\n            for h in list_hr:\n                b = h.replace(self.apath, path_bin)\n                b = b.replace(self.ext[0], '.pt')\n                self.images_hr.append(b)\n                self._check_and_load(args.ext, h, b, verbose=True) \n            for i, ll in enumerate(list_lr):\n                for l in ll:\n                    #pdb.set_trace()\n                    b = l.replace(self.apath, path_bin)\n                    b = b.replace(self.ext[1], '.pt')\n                    self.images_lr[i].append(b)\n                    self._check_and_load(args.ext, l, b, verbose=True) \n        if train:\n            n_patches = args.batch_size * args.test_every\n            n_images = len(args.data_train) * len(self.images_hr)\n            if n_images == 0:\n                self.repeat = 0\n            else:\n                self.repeat = max(n_patches // n_images, 1)\n\n    # Below functions as used to prepare images\n    def _scan(self):\n        names_hr = sorted(\n            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0]))\n        )\n        names_lr = [[] for _ in self.scale]\n        for f in names_hr:\n            filename,_ = os.path.splitext(os.path.basename(f))[0].split('_')\n            for si, s in enumerate(self.scale):\n                names_lr[si].append(os.path.join(\n                    self.dir_lr, 'X{}/{}{}{}'.format(\n                        s, filename, '_LR', self.ext[1]\n                    )\n                ))\n\n        return names_hr, names_lr\n\n    def _set_filesystem(self, dir_data):\n        self.apath = os.path.join(dir_data, self.name)\n        self.dir_hr = os.path.join(self.apath, 'HR')\n        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')\n        if self.input_large: self.dir_lr += 'L'\n        self.ext = ('.png', '.png')\n\n    def _check_and_load(self, ext, img, f, verbose=True):\n        if not os.path.isfile(f) or ext.find('reset') >= 0:\n            if verbose:\n                print('Making a binary: {}'.format(f))\n            with 
open(f, 'wb') as _f:\n                pickle.dump(imageio.imread(img), _f)\n\n    def __getitem__(self, idx):\n        lr, hr, filename = self._load_file(idx)\n        pair = self.get_patch(lr, hr)\n        pair = common.set_channel(*pair, n_channels=self.args.n_colors)\n        pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range)\n\n        return pair_t[0], pair_t[1], filename\n\n    def __len__(self):\n        if self.train:\n            return len(self.images_hr) * self.repeat\n        else:\n            return len(self.images_hr)\n\n    def _get_index(self, idx):\n        if self.train:\n            return idx % len(self.images_hr)\n        else:\n            return idx\n\n    def _load_file(self, idx):\n        idx = self._get_index(idx)\n        f_hr = self.images_hr[idx]\n        f_lr = self.images_lr[self.idx_scale][idx]\n        #print('！！!！!！!!！',f_lr)\n        #pdb.set_trace()\n\n        filename, _ = os.path.splitext(os.path.basename(f_hr))\n        if self.args.ext == 'img' or self.benchmark:\n            hr = imageio.imread(f_hr)\n            lr = imageio.imread(f_lr)\n        elif self.args.ext.find('sep') >= 0:\n            with open(f_hr, 'rb') as _f:\n                hr = pickle.load(_f)\n            with open(f_lr, 'rb') as _f:\n                lr = pickle.load(_f)\n\n        return lr, hr, filename\n\n    def get_patch(self, lr, hr):\n        scale = self.scale[self.idx_scale]\n        if self.train:\n            lr, hr = common.get_patch(\n                lr, hr,\n                patch_size=self.args.patch_size,\n                scale=scale,\n                multi=(len(self.scale) > 1),\n                input_large=self.input_large\n            )\n            #print(hr.shape)\n            if not self.args.no_augment: lr, hr = common.augment(lr, hr)\n        else:\n            ih, iw = lr.shape[:2]\n            hr = hr[0:ih * scale, 0:iw * scale]\n\n        return lr, hr\n\n    def set_scale(self, idx_scale):\n        if not self.input_large:\n            self.idx_scale = idx_scale\n        else:\n            self.idx_scale = random.randint(0, len(self.scale) - 1)\n\n"
  },
  {
    "path": "src/data/video.py",
    "content": "import os\n\nfrom data import common\n\nimport cv2\nimport numpy as np\nimport imageio\n\nimport torch\nimport torch.utils.data as data\n\nclass Video(data.Dataset):\n    def __init__(self, args, name='Video', train=False, benchmark=False):\n        self.args = args\n        self.name = name\n        self.scale = args.scale\n        self.idx_scale = 0\n        self.train = False\n        self.do_eval = False\n        self.benchmark = benchmark\n\n        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))\n        self.vidcap = cv2.VideoCapture(args.dir_demo)\n        self.n_frames = 0\n        self.total_frames = int(self.vidcap.get(cv2.CAP_PROP_FRAME_COUNT))\n\n    def __getitem__(self, idx):\n        success, lr = self.vidcap.read()\n        if success:\n            self.n_frames += 1\n            lr, = common.set_channel(lr, n_channels=self.args.n_colors)\n            lr_t, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)\n\n            return lr_t, -1, '{}_{:0>5}'.format(self.filename, self.n_frames)\n        else:\n            vidcap.release()\n            return None\n\n    def __len__(self):\n        return self.total_frames\n\n    def set_scale(self, idx_scale):\n        self.idx_scale = idx_scale\n\n"
  },
  {
    "path": "src/dataloader.py",
    "content": "import threading\nimport random\n\nimport torch\nimport torch.multiprocessing as multiprocessing\nfrom torch.utils.data import DataLoader\nfrom torch.utils.data import SequentialSampler\nfrom torch.utils.data import RandomSampler\nfrom torch.utils.data import BatchSampler\nfrom torch.utils.data import _utils\nfrom torch.utils.data.dataloader import _DataLoaderIter\n\nfrom torch.utils.data._utils import collate\nfrom torch.utils.data._utils import signal_handling\nfrom torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL\nfrom torch.utils.data._utils import ExceptionWrapper\nfrom torch.utils.data._utils import IS_WINDOWS\nfrom torch.utils.data._utils.worker import ManagerWatchdog\n\nfrom torch._six import queue\n\ndef _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id):\n    try:\n        collate._use_shared_memory = True\n        signal_handling._set_worker_signal_handlers()\n\n        torch.set_num_threads(1)\n        random.seed(seed)\n        torch.manual_seed(seed)\n\n        data_queue.cancel_join_thread()\n\n        if init_fn is not None:\n            init_fn(worker_id)\n\n        watchdog = ManagerWatchdog()\n\n        while watchdog.is_alive():\n            try:\n                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)\n            except queue.Empty:\n                continue\n\n            if r is None:\n                assert done_event.is_set()\n                return\n            elif done_event.is_set():\n                continue\n\n            idx, batch_indices = r\n            try:\n                idx_scale = 0\n                if len(scale) > 1 and dataset.train:\n                    idx_scale = random.randrange(0, len(scale))\n                    dataset.set_scale(idx_scale)\n\n                samples = collate_fn([dataset[i] for i in batch_indices])\n                samples.append(idx_scale)\n            except Exception:\n                data_queue.put((idx, ExceptionWrapper(sys.exc_info())))\n            else:\n                data_queue.put((idx, samples))\n                del samples\n\n    except KeyboardInterrupt:\n        pass\n\nclass _MSDataLoaderIter(_DataLoaderIter):\n\n    def __init__(self, loader):\n        self.dataset = loader.dataset\n        self.scale = loader.scale\n        self.collate_fn = loader.collate_fn\n        self.batch_sampler = loader.batch_sampler\n        self.num_workers = loader.num_workers\n        self.pin_memory = loader.pin_memory and torch.cuda.is_available()\n        self.timeout = loader.timeout\n\n        self.sample_iter = iter(self.batch_sampler)\n\n        base_seed = torch.LongTensor(1).random_().item()\n\n        if self.num_workers > 0:\n            self.worker_init_fn = loader.worker_init_fn\n            self.worker_queue_idx = 0\n            self.worker_result_queue = multiprocessing.Queue()\n            self.batches_outstanding = 0\n            self.worker_pids_set = False\n            self.shutdown = False\n            self.send_idx = 0\n            self.rcvd_idx = 0\n            self.reorder_dict = {}\n            self.done_event = multiprocessing.Event()\n\n            base_seed = torch.LongTensor(1).random_()[0]\n\n            self.index_queues = []\n            self.workers = []\n            for i in range(self.num_workers):\n                index_queue = multiprocessing.Queue()\n                index_queue.cancel_join_thread()\n                w = multiprocessing.Process(\n                    target=_ms_loop,\n                    
args=(\n                        self.dataset,\n                        index_queue,\n                        self.worker_result_queue,\n                        self.done_event,\n                        self.collate_fn,\n                        self.scale,\n                        base_seed + i,\n                        self.worker_init_fn,\n                        i\n                    )\n                )\n                w.daemon = True\n                w.start()\n                self.index_queues.append(index_queue)\n                self.workers.append(w)\n\n            if self.pin_memory:\n                self.data_queue = queue.Queue()\n                pin_memory_thread = threading.Thread(\n                    target=_utils.pin_memory._pin_memory_loop,\n                    args=(\n                        self.worker_result_queue,\n                        self.data_queue,\n                        torch.cuda.current_device(),\n                        self.done_event\n                    )\n                )\n                pin_memory_thread.daemon = True\n                pin_memory_thread.start()\n                self.pin_memory_thread = pin_memory_thread\n            else:\n                self.data_queue = self.worker_result_queue\n\n            _utils.signal_handling._set_worker_pids(\n                id(self), tuple(w.pid for w in self.workers)\n            )\n            _utils.signal_handling._set_SIGCHLD_handler()\n            self.worker_pids_set = True\n\n            for _ in range(2 * self.num_workers):\n                self._put_indices()\n\n\nclass MSDataLoader(DataLoader):\n\n    def __init__(self, cfg, *args, **kwargs):\n        super(MSDataLoader, self).__init__(\n            *args, **kwargs, num_workers=cfg.n_threads\n        )\n        self.scale = cfg.scale\n\n    def __iter__(self):\n        return _MSDataLoaderIter(self)\n\n"
  },
  {
    "path": "src/demo.sh",
    "content": "# EDSR baseline model (x2) + JPEG augmentation\n#python3 main.py --model MatrixModel --scale 4 --patch_size 192 --save MatrixModelG7_x4 --reset --pre_train /media/zrh/cc9cb710-2fc7-4382-81ff-649502a83b92/EDSR-PyTorch-master/experiment/MatrixModelG6_x4/model/model_best.pt\n#python main.py --model EDSR --scale 2 --patch_size 96 --save edsr_baseline_x2 --reset --data_train DIV2K+DIV2K-Q75 --data_test DIV2K+DIV2K-Q75\n\n# EDSR baseline model (x3) - from EDSR baseline model (x2)\n#python main.py --model EDSR --scale 3 --patch_size 144 --save edsr_baseline_x3 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]\n\n# EDSR baseline model (x4) - from EDSR baseline model (x2)\n#python main.py --model EDSR --scale 4 --save edsr_baseline_x4 --reset --pre_train [pre-trained EDSR_baseline_x2 model dir]\n\n# EDSR in the paper (x2)\n#python main.py --model EDSR --scale 2 --save edsr_x2 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset\n\n# EDSR in the paper (x3) - from EDSR (x2)\n#python main.py --model EDSR --scale 3 --save edsr_x3 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR model dir]\n\n# EDSR in the paper (x4) - from EDSR (x2)\n#python main.py --model EDSR --scale 4 --save edsr_x4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --reset --pre_train [pre-trained EDSR_x2 model dir]\n\n# MDSR baseline model\n#python main.py --template MDSR --model MDSR --scale 2+3+4 --save MDSR_baseline --reset --save_models\n\n# MDSR in the paper\n#python main.py --template MDSR --model MDSR --scale 2+3+4 --n_resblocks 80 --save MDSR --reset --save_models\n\n# Standard benchmarks (Ex. EDSR_baseline_x4)\n#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --pre_train download --test_only --self_ensemble\n\n#python main.py --data_test Set5+Set14+B100+Urban100+DIV2K --data_range 801-900 --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train download --test_only --self_ensemble\n\n# Test your own images\n#python main.py --data_test Demo --scale 4 --pre_train download --test_only --save_results\n\n# Advanced - Test with JPEG images \n#python main.py --model MDSR --data_test Demo --scale 2+3+4 --pre_train download --test_only --save_results\n\n# Advanced - Training with adversarial loss\n#python main.py --template GAN --scale 4 --save edsr_gan --reset --patch_size 96 --loss 5*VGG54+0.15*GAN --pre_train download\n\n# RDN BI model (x2)\n#python3.6 main.py --scale 2 --save RDN_D16C8G64_BIx2 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 64 --reset\n# RDN BI model (x3)\n#python3.6 main.py --scale 3 --save RDN_D16C8G64_BIx3 --model RDN --epochs 200 --batch_size 16 --data_range 801-805 --patch_size 96 --reset\n# RDN BI model (x4)\n#python main.py --scale 4 --save RDN9_D16C8G64_BIx4 --model RDN --epochs 400 --batch_size 16 --patch_size 128 --reset #--pre_train /home/visionx/wwl/project/EDSR-PyTorch-master/experiment/RDN7_D16C8G64_BIx4/model/model_best.pt\n\n# RCAN_BIX2_G10R20P48, input=48x48, output=96x96\n# pretrained model can be downloaded from https://www.dropbox.com/s/mjbcqkd4nwhr6nu/models_ECCV2018RCAN.zip?dl=0\n#python main.py --template RCAN --save RCAN_BIX2_G10R20P48 --scale 2 --reset --save_results --patch_size 96\n# RCAN_BIX3_G10R20P48, input=48x48, output=144x144\n#python main.py --template RCAN --save RCAN_BIX3_G10R20P48 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt\n# RCAN_BIX4_G10R20P48, input=48x48, output=192x192\n#python 
main.py --template RCAN2 --data_test Set5+Set14+B100+Urban100+Manga109 --data_range 801-900 --scale 8 --pre_train ../experiment/RCAN81_BIX8_G10R20P48/model/model_best.pt --test_only --save RCAN_test --save_results\n#python main.py --template RCAN2 --save RCAN3_BIX4_G10R20P48 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt\n# RCAN_BIX8_G10R20P48, input=48x48, output=384x384\n#python main.py --template RCAN2 --save RCAN81_BIX8_G10R20P48 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX8.pt\n\n# HAN BI model (x2)\n#python main.py --template HAN --save HANx2 --scale 2 --reset --save_results --patch_size 96 --pre_train ../experiment/model/RCAN_BIX2.pt\n# HAN BI model (x3)\n#python main.py --template HAN --save HANx3 --scale 3 --reset --save_results --patch_size 144 --pre_train ../experiment/model/RCAN_BIX2.pt\n# HAN BI model (x4)\n#python main.py --template HAN --save HANx4 --scale 4 --reset --save_results --patch_size 192 --pre_train ../experiment/model/RCAN_BIX2.pt\n# HAN BI model (x8)\n#python main.py --template HAN --save HANx8 --scale 8 --reset --save_results --patch_size 384 --pre_train ../experiment/model/RCAN_BIX2.pt\n# Test HAN\n#python main.py --template HAN --data_test Set5+Set14+B100+Urban100+Manga109 --data_range 801-900 --scale 2 --pre_train ../experiment/HAN/HAN_BIX2.pt --test_only --save HANx2_test --save_results\n"
  },
  {
    "path": "src/loss/__init__.py",
    "content": "import os\r\nfrom importlib import import_module\r\n\r\nimport matplotlib\r\nmatplotlib.use('Agg')\r\nimport matplotlib.pyplot as plt\r\n\r\nimport numpy as np\r\n\r\nimport torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\n\r\nclass Loss(nn.modules.loss._Loss):\r\n    def __init__(self, args, ckp):\r\n        super(Loss, self).__init__()\r\n        print('Preparing loss function:')\r\n\r\n        self.n_GPUs = args.n_GPUs\r\n        self.loss = []\r\n        self.loss_module = nn.ModuleList()\r\n        for loss in args.loss.split('+'):\r\n            weight, loss_type = loss.split('*')\r\n            if loss_type == 'MSE':\r\n                loss_function = nn.MSELoss()\r\n            elif loss_type == 'L1':\r\n                loss_function = nn.L1Loss()\r\n            elif loss_type.find('VGG') >= 0:\r\n                module = import_module('loss.vgg')\r\n                loss_function = getattr(module, 'VGG')(\r\n                    loss_type[3:],\r\n                    rgb_range=args.rgb_range\r\n                )\r\n            elif loss_type.find('GAN') >= 0:\r\n                module = import_module('loss.adversarial')\r\n                loss_function = getattr(module, 'Adversarial')(\r\n                    args,\r\n                    loss_type\r\n                )\r\n\r\n            self.loss.append({\r\n                'type': loss_type,\r\n                'weight': float(weight),\r\n                'function': loss_function}\r\n            )\r\n            if loss_type.find('GAN') >= 0:\r\n                self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})\r\n\r\n        if len(self.loss) > 1:\r\n            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})\r\n\r\n        for l in self.loss:\r\n            if l['function'] is not None:\r\n                print('{:.3f} * {}'.format(l['weight'], l['type']))\r\n                self.loss_module.append(l['function'])\r\n\r\n        self.log = torch.Tensor()\r\n\r\n        device = torch.device('cpu' if args.cpu else 'cuda')\r\n        self.loss_module.to(device)\r\n        if args.precision == 'half': self.loss_module.half()\r\n        if not args.cpu and args.n_GPUs > 1:\r\n            self.loss_module = nn.DataParallel(\r\n                self.loss_module, range(args.n_GPUs)\r\n            )\r\n\r\n        if args.load != '': self.load(ckp.dir, cpu=args.cpu)\r\n\r\n    def forward(self, sr, hr):\r\n        losses = []\r\n        for i, l in enumerate(self.loss):\r\n            if l['function'] is not None:\r\n                loss = l['function'](sr, hr)\r\n                effective_loss = l['weight'] * loss\r\n                losses.append(effective_loss)\r\n                self.log[-1, i] += effective_loss.item()\r\n            elif l['type'] == 'DIS':\r\n                self.log[-1, i] += self.loss[i - 1]['function'].loss\r\n\r\n        loss_sum = sum(losses)\r\n        if len(self.loss) > 1:\r\n            self.log[-1, -1] += loss_sum.item()\r\n\r\n        return loss_sum\r\n\r\n    def step(self):\r\n        for l in self.get_loss_module():\r\n            if hasattr(l, 'scheduler'):\r\n                l.scheduler.step()\r\n\r\n    def start_log(self):\r\n        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))\r\n\r\n    def end_log(self, n_batches):\r\n        self.log[-1].div_(n_batches)\r\n\r\n    def display_loss(self, batch):\r\n        n_samples = batch + 1\r\n        log = []\r\n        for l, c in zip(self.loss, self.log[-1]):\r\n          
  log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))\r\n\r\n        return ''.join(log)\r\n\r\n    def plot_loss(self, apath, epoch):\r\n        axis = np.linspace(1, epoch, epoch)\r\n        for i, l in enumerate(self.loss):\r\n            label = '{} Loss'.format(l['type'])\r\n            fig = plt.figure()\r\n            plt.title(label)\r\n            plt.plot(axis, self.log[:, i].numpy(), label=label)\r\n            plt.legend()\r\n            plt.xlabel('Epochs')\r\n            plt.ylabel('Loss')\r\n            plt.grid(True)\r\n            plt.savefig(os.path.join(apath, 'loss_{}.pdf'.format(l['type'])))\r\n            plt.close(fig)\r\n\r\n    def get_loss_module(self):\r\n        if self.n_GPUs == 1:\r\n            return self.loss_module\r\n        else:\r\n            return self.loss_module.module\r\n\r\n    def save(self, apath):\r\n        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))\r\n        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))\r\n\r\n    def load(self, apath, cpu=False):\r\n        if cpu:\r\n            kwargs = {'map_location': lambda storage, loc: storage}\r\n        else:\r\n            kwargs = {}\r\n\r\n        self.load_state_dict(torch.load(\r\n            os.path.join(apath, 'loss.pt'),\r\n            **kwargs\r\n        ))\r\n        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))\r\n        for l in self.get_loss_module():\r\n            if hasattr(l, 'scheduler'):\r\n                for _ in range(len(self.log)): l.scheduler.step()\r\n\r\n"
  },
  {
    "path": "src/loss/adversarial.py",
    "content": "import utility\nfrom types import SimpleNamespace\n\nfrom model import common\nfrom loss import discriminator\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\n\nclass Adversarial(nn.Module):\n    def __init__(self, args, gan_type):\n        super(Adversarial, self).__init__()\n        self.gan_type = gan_type\n        self.gan_k = args.gan_k\n        self.dis = discriminator.Discriminator(args)\n        if gan_type == 'WGAN_GP':\n            # see https://arxiv.org/pdf/1704.00028.pdf pp.4\n            optim_dict = {\n                'optimizer': 'ADAM',\n                'betas': (0, 0.9),\n                'epsilon': 1e-8,\n                'lr': 1e-5,\n                'weight_decay': args.weight_decay,\n                'decay': args.decay,\n                'gamma': args.gamma\n            }\n            optim_args = SimpleNamespace(**optim_dict)\n        else:\n            optim_args = args\n\n        self.optimizer = utility.make_optimizer(optim_args, self.dis)\n\n    def forward(self, fake, real):\n        # updating discriminator...\n        self.loss = 0\n        fake_detach = fake.detach()     # do not backpropagate through G\n        for _ in range(self.gan_k):\n            self.optimizer.zero_grad()\n            # d: B x 1 tensor\n            d_fake = self.dis(fake_detach)\n            d_real = self.dis(real)\n            retain_graph = False\n            if self.gan_type == 'GAN':\n                loss_d = self.bce(d_real, d_fake)\n            elif self.gan_type.find('WGAN') >= 0:\n                loss_d = (d_fake - d_real).mean()\n                if self.gan_type.find('GP') >= 0:\n                    epsilon = torch.rand_like(fake).view(-1, 1, 1, 1)\n                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)\n                    hat.requires_grad = True\n                    d_hat = self.dis(hat)\n                    gradients = torch.autograd.grad(\n                        outputs=d_hat.sum(), inputs=hat,\n                        retain_graph=True, create_graph=True, only_inputs=True\n                    )[0]\n                    gradients = gradients.view(gradients.size(0), -1)\n                    gradient_norm = gradients.norm(2, dim=1)\n                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()\n                    loss_d += gradient_penalty\n            # from ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks\n            elif self.gan_type == 'RGAN':\n                better_real = d_real - d_fake.mean(dim=0, keepdim=True)\n                better_fake = d_fake - d_real.mean(dim=0, keepdim=True)\n                loss_d = self.bce(better_real, better_fake)\n                retain_graph = True\n\n            # Discriminator update\n            self.loss += loss_d.item()\n            loss_d.backward(retain_graph=retain_graph)\n            self.optimizer.step()\n\n            if self.gan_type == 'WGAN':\n                for p in self.dis.parameters():\n                    p.data.clamp_(-1, 1)\n\n        self.loss /= self.gan_k\n\n        # updating generator...\n        d_fake_bp = self.dis(fake)      # for backpropagation, use fake as it is\n        if self.gan_type == 'GAN':\n            label_real = torch.ones_like(d_fake_bp)\n            loss_g = F.binary_cross_entropy_with_logits(d_fake_bp, label_real)\n        elif self.gan_type.find('WGAN') >= 0:\n            loss_g = -d_fake_bp.mean()\n        elif self.gan_type == 'RGAN':\n            better_real = 
d_real - d_fake_bp.mean(dim=0, keepdim=True)\n            better_fake = d_fake_bp - d_real.mean(dim=0, keepdim=True)\n            loss_g = self.bce(better_fake, better_real)\n\n        # Generator loss\n        return loss_g\n    \n    def state_dict(self, *args, **kwargs):\n        state_discriminator = self.dis.state_dict(*args, **kwargs)\n        state_optimizer = self.optimizer.state_dict()\n\n        return dict(**state_discriminator, **state_optimizer)\n\n    def bce(self, real, fake):\n        label_real = torch.ones_like(real)\n        label_fake = torch.zeros_like(fake)\n        bce_real = F.binary_cross_entropy_with_logits(real, label_real)\n        bce_fake = F.binary_cross_entropy_with_logits(fake, label_fake)\n        bce_loss = bce_real + bce_fake\n        return bce_loss\n               \n# Some references\n# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py\n# OR\n# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py\n"
  },
  {
    "path": "src/loss/discriminator.py",
    "content": "from model import common\n\nimport torch.nn as nn\n\nclass Discriminator(nn.Module):\n    '''\n        output is not normalized\n    '''\n    def __init__(self, args):\n        super(Discriminator, self).__init__()\n\n        in_channels = args.n_colors\n        out_channels = 64\n        depth = 7\n\n        def _block(_in_channels, _out_channels, stride=1):\n            return nn.Sequential(\n                nn.Conv2d(\n                    _in_channels,\n                    _out_channels,\n                    3,\n                    padding=1,\n                    stride=stride,\n                    bias=False\n                ),\n                nn.BatchNorm2d(_out_channels),\n                nn.LeakyReLU(negative_slope=0.2, inplace=True)\n            )\n\n        m_features = [_block(in_channels, out_channels)]\n        for i in range(depth):\n            in_channels = out_channels\n            if i % 2 == 1:\n                stride = 1\n                out_channels *= 2\n            else:\n                stride = 2\n            m_features.append(_block(in_channels, out_channels, stride=stride))\n\n        patch_size = args.patch_size // (2**((depth + 1) // 2))\n        m_classifier = [\n            nn.Linear(out_channels * patch_size**2, 1024),\n            nn.LeakyReLU(negative_slope=0.2, inplace=True),\n            nn.Linear(1024, 1)\n        ]\n\n        self.features = nn.Sequential(*m_features)\n        self.classifier = nn.Sequential(*m_classifier)\n\n    def forward(self, x):\n        features = self.features(x)\n        output = self.classifier(features.view(features.size(0), -1))\n\n        return output\n\n"
  },
  {
    "path": "src/loss/vgg.py",
    "content": "from model import common\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torchvision.models as models\n\nclass VGG(nn.Module):\n    def __init__(self, conv_index, rgb_range=1):\n        super(VGG, self).__init__()\n        vgg_features = models.vgg19(pretrained=True).features\n        modules = [m for m in vgg_features]\n        if conv_index.find('22') >= 0:\n            self.vgg = nn.Sequential(*modules[:8])\n        elif conv_index.find('54') >= 0:\n            self.vgg = nn.Sequential(*modules[:35])\n\n        vgg_mean = (0.485, 0.456, 0.406)\n        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)\n        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)\n        for p in self.parameters():\n            p.requires_grad = False\n\n    def forward(self, sr, hr):\n        def _forward(x):\n            x = self.sub_mean(x)\n            x = self.vgg(x)\n            return x\n            \n        vgg_sr = _forward(sr)\n        with torch.no_grad():\n            vgg_hr = _forward(hr.detach())\n\n        loss = F.mse_loss(vgg_sr, vgg_hr)\n\n        return loss\n"
  },
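  {
    "path": "examples/vgg_slice_sketch.py",
    "content": "# Illustrative sketch only -- NOT part of the original repository.\n# Shows which torchvision VGG19 layers the slices in src/loss/vgg.py keep:\n# modules[:8] ends with conv2_2 (index 7) and modules[:35] ends with\n# conv5_4 (index 34), i.e. both perceptual features are taken before the\n# following ReLU. pretrained=False is used here only to keep the sketch\n# download-free; the layer layout is identical.\nimport torchvision.models as models\n\nfeatures = models.vgg19(pretrained=False).features\nprint(features[7])   # Conv2d(128, 128, ...) -- conv2_2, last layer of [:8]\nprint(features[34])  # Conv2d(512, 512, ...) -- conv5_4, last layer of [:35]\n"
  },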
  {
    "path": "src/main.py",
    "content": "import torch\n\nimport utility\nimport data\nimport model\nimport loss\nfrom option import args\nfrom trainer import Trainer\n\ntorch.manual_seed(args.seed)\ncheckpoint = utility.checkpoint(args)\n\ndef main():\n    global model\n    if args.data_test == ['video']:\n        from videotester import VideoTester\n        model = model.Model(args, checkpoint)\n        t = VideoTester(args, model, checkpoint)\n        t.test()\n    else:\n        if checkpoint.ok:\n            loader = data.Data(args)\n            _model = model.Model(args, checkpoint)\n            _loss = loss.Loss(args, checkpoint) if not args.test_only else None\n            t = Trainer(args, loader, _model, _loss, checkpoint)\n            while not t.terminate():\n                t.train()\n                t.test()\n\n            checkpoint.done()\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "src/model/__init__.py",
    "content": "import os\nfrom importlib import import_module\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel as P\nimport torch.utils.model_zoo\nos.environ[\"CUDA_VISIBLE_DEVICES\"] = '0,1'\n\nclass Model(nn.Module):\n    def __init__(self, args, ckp):\n        super(Model, self).__init__()\n        print('Making model...')\n\n        self.scale = args.scale\n        self.idx_scale = 0\n        self.input_large = (args.model == 'VDSR')\n        self.self_ensemble = args.self_ensemble\n        self.chop = args.chop\n        self.precision = args.precision\n        self.cpu = args.cpu\n        self.device = torch.device('cpu' if args.cpu else 'cuda')\n        self.n_GPUs = args.n_GPUs\n        self.save_models = args.save_models\n\n        module = import_module('model.' + args.model.lower())\n        self.model = module.make_model(args).to(self.device)\n        if args.precision == 'half':\n            self.model.half()\n\n        self.load(\n            ckp.get_path('model'),\n            pre_train=args.pre_train,\n            resume=args.resume,\n            cpu=args.cpu\n        )\n        print(self.model, file=ckp.log_file)\n\n    def forward(self, x, idx_scale):\n        self.idx_scale = idx_scale\n        if hasattr(self.model, 'set_scale'):\n            self.model.set_scale(idx_scale)\n\n        if self.training:\n            if self.n_GPUs > 1:\n                return P.data_parallel(self.model, x, range(self.n_GPUs))\n            else:\n                return self.model(x)\n        else:\n            if self.chop:\n                forward_function = self.forward_chop\n            else:\n                forward_function = self.model.forward\n\n            if self.self_ensemble:\n                return self.forward_x8(x, forward_function=forward_function)\n            else:\n                return forward_function(x)\n\n    def save(self, apath, epoch, is_best=False):\n        save_dirs = [os.path.join(apath, 'model_latest.pt')]\n\n        if is_best:\n            save_dirs.append(os.path.join(apath, 'model_best.pt'))\n        if self.save_models:\n            save_dirs.append(\n                os.path.join(apath, 'model_{}.pt'.format(epoch))\n            )\n\n        for s in save_dirs:\n            torch.save(self.model.state_dict(), s)\n\n    def load(self, apath, pre_train='', resume=-1, cpu=False):\n        load_from = None\n        kwargs = {}\n        if cpu:\n            kwargs = {'map_location': lambda storage, loc: storage}\n\n        if resume == -1:\n            load_from = torch.load(\n                os.path.join(apath, 'model_latest.pt'),\n                **kwargs\n            )\n        elif resume == 0:\n            if pre_train == 'download':\n                print('Download the model')\n                dir_model = os.path.join('..', 'models')\n                os.makedirs(dir_model, exist_ok=True)\n                load_from = torch.utils.model_zoo.load_url(\n                    self.model.url,\n                    model_dir=dir_model,\n                    **kwargs\n                )\n            elif pre_train:\n                print('Load the model from {}'.format(pre_train))\n                load_from = torch.load(pre_train, **kwargs)\n        else:\n            load_from = torch.load(\n                os.path.join(apath, 'model_{}.pt'.format(resume)),\n                **kwargs\n            )\n\n        if load_from:\n            self.model.load_state_dict(load_from, strict=False)\n\n    def forward_chop(self, x, shave=10, min_size=160000):\n   
     scale = self.scale[self.idx_scale]\n        n_GPUs = min(self.n_GPUs, 4)\n        b, c, h, w = x.size()\n        h_half, w_half = h // 2, w // 2\n        h_size, w_size = h_half + shave, w_half + shave\n        lr_list = [\n            x[:, :, 0:h_size, 0:w_size],\n            x[:, :, 0:h_size, (w - w_size):w],\n            x[:, :, (h - h_size):h, 0:w_size],\n            x[:, :, (h - h_size):h, (w - w_size):w]]\n\n        if w_size * h_size < min_size:\n            sr_list = []\n            for i in range(0, 4, n_GPUs):\n                lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)\n                sr_batch = self.model(lr_batch)\n                sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))\n        else:\n            sr_list = [\n                self.forward_chop(patch, shave=shave, min_size=min_size) \\\n                for patch in lr_list\n            ]\n\n        h, w = scale * h, scale * w\n        h_half, w_half = scale * h_half, scale * w_half\n        h_size, w_size = scale * h_size, scale * w_size\n        shave *= scale\n\n        output = x.new(b, c, h, w)\n        output[:, :, 0:h_half, 0:w_half] \\\n            = sr_list[0][:, :, 0:h_half, 0:w_half]\n        output[:, :, 0:h_half, w_half:w] \\\n            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]\n        output[:, :, h_half:h, 0:w_half] \\\n            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]\n        output[:, :, h_half:h, w_half:w] \\\n            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]\n\n        return output\n\n    def forward_x8(self, *args, forward_function=None):\n        def _transform(v, op):\n            if self.precision != 'single': v = v.float()\n\n            v2np = v.data.cpu().numpy()\n            if op == 'v':\n                tfnp = v2np[:, :, :, ::-1].copy()\n            elif op == 'h':\n                tfnp = v2np[:, :, ::-1, :].copy()\n            elif op == 't':\n                tfnp = v2np.transpose((0, 1, 3, 2)).copy()\n\n            ret = torch.Tensor(tfnp).to(self.device)\n            if self.precision == 'half': ret = ret.half()\n\n            return ret\n\n        list_x = []\n        for a in args:\n            x = [a]\n            for tf in 'v', 'h', 't': x.extend([_transform(_x, tf) for _x in x])\n\n            list_x.append(x)\n\n        list_y = []\n        for x in zip(*list_x):\n            y = forward_function(*x)\n            if not isinstance(y, list): y = [y]\n            if not list_y:\n                list_y = [[_y] for _y in y]\n            else:\n                for _list_y, _y in zip(list_y, y): _list_y.append(_y)\n\n        for _list_y in list_y:\n            for i in range(len(_list_y)):\n                if i > 3:\n                    _list_y[i] = _transform(_list_y[i], 't')\n                if i % 4 > 1:\n                    _list_y[i] = _transform(_list_y[i], 'h')\n                if (i % 4) % 2 == 1:\n                    _list_y[i] = _transform(_list_y[i], 'v')\n\n        y = [torch.cat(_y, dim=0).mean(dim=0, keepdim=True) for _y in list_y]\n        if len(y) == 1: y = y[0]\n\n        return y\n"
  },
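  {
    "path": "examples/self_ensemble_sketch.py",
    "content": "# Illustrative sketch only -- NOT part of the original repository.\n# Mirrors how forward_x8 in src/model/__init__.py enumerates its eight\n# test-time augmentations: each pass over the ops ('v', 'h', 't') doubles\n# the current list, producing all flip/transpose combinations. torch.flip\n# stands in for the numpy round-trip used in the original code.\nimport torch\n\ndef transform(v, op):\n    if op == 'v':  # reverse the last (width) axis\n        return torch.flip(v, dims=[-1])\n    if op == 'h':  # reverse the height axis\n        return torch.flip(v, dims=[-2])\n    if op == 't':  # swap height and width\n        return v.transpose(-2, -1)\n\nif __name__ == '__main__':\n    x = torch.arange(4.).view(1, 1, 2, 2)\n    variants = [x]\n    for tf in 'v', 'h', 't':\n        variants.extend([transform(v, tf) for v in variants])\n    print(len(variants))  # 8 dihedral variants of x\n"
  },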
  {
    "path": "src/model/common.py",
    "content": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ndef default_conv(in_channels, out_channels, kernel_size, bias=True):\n    return nn.Conv2d(\n        in_channels, out_channels, kernel_size,\n        padding=(kernel_size//2), bias=bias)\n\nclass MeanShift(nn.Conv2d):\n    def __init__(\n        self, rgb_range,\n        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):\n\n        super(MeanShift, self).__init__(3, 3, kernel_size=1)\n        std = torch.Tensor(rgb_std)\n        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)\n        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std\n        for p in self.parameters():\n            p.requires_grad = False\n\nclass BasicBlock(nn.Sequential):\n    def __init__(\n        self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,\n        bn=True, act=nn.ReLU(True)):\n\n        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]\n        if bn:\n            m.append(nn.BatchNorm2d(out_channels))\n        if act is not None:\n            m.append(act)\n\n        super(BasicBlock, self).__init__(*m)\n\nclass ResBlock(nn.Module):\n    def __init__(\n        self, conv, n_feats, kernel_size,\n        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(ResBlock, self).__init__()\n        m = []\n        for i in range(2):\n            m.append(conv(n_feats, n_feats, kernel_size, bias=bias))\n            if bn:\n                m.append(nn.BatchNorm2d(n_feats))\n            if i == 0:\n                m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x).mul(self.res_scale)\n        res += x\n\n        return res\n\nclass Upsampler(nn.Sequential):\n    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):\n\n        m = []\n        if (scale & (scale - 1)) == 0:    # Is scale = 2^n?\n            for _ in range(int(math.log(scale, 2))):\n                m.append(conv(n_feats, 4 * n_feats, 3, bias))\n                m.append(nn.PixelShuffle(2))\n                if bn:\n                    m.append(nn.BatchNorm2d(n_feats))\n                if act == 'relu':\n                    m.append(nn.ReLU(True))\n                elif act == 'prelu':\n                    m.append(nn.PReLU(n_feats))\n\n        elif scale == 3:\n            m.append(conv(n_feats, 9 * n_feats, 3, bias))\n            m.append(nn.PixelShuffle(3))\n            if bn:\n                m.append(nn.BatchNorm2d(n_feats))\n            if act == 'relu':\n                m.append(nn.ReLU(True))\n            elif act == 'prelu':\n                m.append(nn.PReLU(n_feats))\n        else:\n            raise NotImplementedError\n\n        super(Upsampler, self).__init__(*m)\n\n"
  },
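  {
    "path": "examples/pixelshuffle_upsampler_sketch.py",
    "content": "# Illustrative sketch only -- NOT part of the original repository.\n# Demonstrates the arithmetic behind Upsampler in src/model/common.py: for\n# a power-of-two scale it stacks log2(scale) stages of (conv to 4x channels,\n# PixelShuffle(2)), each doubling the spatial size while restoring the\n# channel count. n_feats and the input size are arbitrary choices.\nimport torch\nimport torch.nn as nn\n\nn_feats = 8\nstage = nn.Sequential(\n    nn.Conv2d(n_feats, 4 * n_feats, 3, padding=1),\n    nn.PixelShuffle(2),\n)\nx = torch.randn(1, n_feats, 16, 16)\nprint(stage(x).shape)                        # scale 2: (1, 8, 32, 32)\nprint(nn.Sequential(stage, stage)(x).shape)  # scale 4: (1, 8, 64, 64)\n"
  },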
  {
    "path": "src/model/dcn/__init__.py",
    "content": "from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack,\n                          deform_conv, modulated_deform_conv)\n\n__all__ = [\n    'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',\n    'modulated_deform_conv'\n]\n"
  },
  {
    "path": "src/model/dcn/deform_conv.py",
    "content": "import math\nimport logging\n\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Function\nfrom torch.autograd.function import once_differentiable\nfrom torch.nn.modules.utils import _pair\n\nfrom . import deform_conv_cuda\n\nlogger = logging.getLogger('base')\n\n\nclass DeformConvFunction(Function):\n    @staticmethod\n    def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1,\n                deformable_groups=1, im2col_step=64):\n        if input is not None and input.dim() != 4:\n            raise ValueError(\"Expected 4D tensor as input, got {}D tensor instead.\".format(\n                input.dim()))\n        ctx.stride = _pair(stride)\n        ctx.padding = _pair(padding)\n        ctx.dilation = _pair(dilation)\n        ctx.groups = groups\n        ctx.deformable_groups = deformable_groups\n        ctx.im2col_step = im2col_step\n\n        ctx.save_for_backward(input, offset, weight)\n\n        output = input.new_empty(\n            DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))\n\n        ctx.bufs_ = [input.new_empty(0), input.new_empty(0)]  # columns, ones\n\n        if not input.is_cuda:\n            raise NotImplementedError\n        else:\n            cur_im2col_step = min(ctx.im2col_step, input.shape[0])\n            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'\n            deform_conv_cuda.deform_conv_forward_cuda(input, weight, offset, output,\n                                                      ctx.bufs_[0], ctx.bufs_[1], weight.size(3),\n                                                      weight.size(2), ctx.stride[1], ctx.stride[0],\n                                                      ctx.padding[1], ctx.padding[0],\n                                                      ctx.dilation[1], ctx.dilation[0], ctx.groups,\n                                                      ctx.deformable_groups, cur_im2col_step)\n        return output\n\n    @staticmethod\n    @once_differentiable\n    def backward(ctx, grad_output):\n        input, offset, weight = ctx.saved_tensors\n\n        grad_input = grad_offset = grad_weight = None\n\n        if not grad_output.is_cuda:\n            raise NotImplementedError\n        else:\n            cur_im2col_step = min(ctx.im2col_step, input.shape[0])\n            assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'\n\n            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:\n                grad_input = torch.zeros_like(input)\n                grad_offset = torch.zeros_like(offset)\n                deform_conv_cuda.deform_conv_backward_input_cuda(\n                    input, offset, grad_output, grad_input, grad_offset, weight, ctx.bufs_[0],\n                    weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],\n                    ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,\n                    ctx.deformable_groups, cur_im2col_step)\n\n            if ctx.needs_input_grad[2]:\n                grad_weight = torch.zeros_like(weight)\n                deform_conv_cuda.deform_conv_backward_parameters_cuda(\n                    input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1],\n                    weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],\n                    ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,\n                    ctx.deformable_groups, 1, 
cur_im2col_step)\n\n        # One gradient is required per forward input; the six non-tensor\n        # arguments (stride, padding, dilation, groups, deformable_groups,\n        # im2col_step) get None.\n        return (grad_input, grad_offset, grad_weight,\n                None, None, None, None, None, None)\n\n    @staticmethod\n    def _output_size(input, weight, padding, dilation, stride):\n        channels = weight.size(0)\n        output_size = (input.size(0), channels)\n        for d in range(input.dim() - 2):\n            in_size = input.size(d + 2)\n            pad = padding[d]\n            kernel = dilation[d] * (weight.size(d + 2) - 1) + 1\n            stride_ = stride[d]\n            output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )\n        if not all(map(lambda s: s > 0, output_size)):\n            raise ValueError(\"convolution input is too small (output would be {})\".format('x'.join(\n                map(str, output_size))))\n        return output_size\n\n\nclass ModulatedDeformConvFunction(Function):\n    @staticmethod\n    def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1,\n                groups=1, deformable_groups=1):\n        ctx.stride = stride\n        ctx.padding = padding\n        ctx.dilation = dilation\n        ctx.groups = groups\n        ctx.deformable_groups = deformable_groups\n        ctx.with_bias = bias is not None\n        if not ctx.with_bias:\n            bias = input.new_empty(1)  # fake tensor\n        if not input.is_cuda:\n            raise NotImplementedError\n        if weight.requires_grad or mask.requires_grad or offset.requires_grad \\\n                or input.requires_grad:\n            ctx.save_for_backward(input, offset, mask, weight, bias)\n        output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))\n        ctx._bufs = [input.new_empty(0), input.new_empty(0)]\n        deform_conv_cuda.modulated_deform_conv_cuda_forward(\n            input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1], weight.shape[2],\n            weight.shape[3], ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation,\n            ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias)\n        return output\n\n    @staticmethod\n    @once_differentiable\n    def backward(ctx, grad_output):\n        if not grad_output.is_cuda:\n            raise NotImplementedError\n        input, offset, mask, weight, bias = ctx.saved_tensors\n        grad_input = torch.zeros_like(input)\n        grad_offset = torch.zeros_like(offset)\n        grad_mask = torch.zeros_like(mask)\n        grad_weight = torch.zeros_like(weight)\n        grad_bias = torch.zeros_like(bias)\n        deform_conv_cuda.modulated_deform_conv_cuda_backward(\n            input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], grad_input, grad_weight,\n            grad_bias, grad_offset, grad_mask, grad_output, weight.shape[2], weight.shape[3],\n            ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,\n            ctx.groups, ctx.deformable_groups, ctx.with_bias)\n        if not ctx.with_bias:\n            grad_bias = None\n\n        return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None,\n                None)\n\n    @staticmethod\n    def _infer_shape(ctx, input, weight):\n        n = input.size(0)\n        channels_out = weight.size(0)\n        height, width = input.shape[2:4]\n        kernel_h, kernel_w = weight.shape[2:4]\n        height_out = (height + 2 * ctx.padding - (ctx.dilation *\n                                                  (kernel_h - 1) + 1)) // ctx.stride + 1\n        width_out = (width + 2 * 
ctx.padding - (ctx.dilation *\n                                                (kernel_w - 1) + 1)) // ctx.stride + 1\n        return n, channels_out, height_out, width_out\n\n\ndeform_conv = DeformConvFunction.apply\nmodulated_deform_conv = ModulatedDeformConvFunction.apply\n\n\nclass DeformConv(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,\n                 groups=1, deformable_groups=1, bias=False):\n        super(DeformConv, self).__init__()\n\n        assert not bias\n        assert in_channels % groups == 0, \\\n            'in_channels {} is not divisible by groups {}'.format(\n                in_channels, groups)\n        assert out_channels % groups == 0, \\\n            'out_channels {} is not divisible by groups {}'.format(\n                out_channels, groups)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.kernel_size = _pair(kernel_size)\n        self.stride = _pair(stride)\n        self.padding = _pair(padding)\n        self.dilation = _pair(dilation)\n        self.groups = groups\n        self.deformable_groups = deformable_groups\n\n        self.weight = nn.Parameter(\n            torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        n = self.in_channels\n        for k in self.kernel_size:\n            n *= k\n        stdv = 1. / math.sqrt(n)\n        self.weight.data.uniform_(-stdv, stdv)\n\n    def forward(self, x, offset):\n        return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,\n                           self.groups, self.deformable_groups)\n\n\nclass DeformConvPack(DeformConv):\n    def __init__(self, *args, **kwargs):\n        super(DeformConvPack, self).__init__(*args, **kwargs)\n\n        self.conv_offset = nn.Conv2d(\n            self.in_channels,\n            self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],\n            kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),\n            bias=True)\n        self.init_offset()\n\n    def init_offset(self):\n        self.conv_offset.weight.data.zero_()\n        self.conv_offset.bias.data.zero_()\n\n    def forward(self, x):\n        offset = self.conv_offset(x)\n        return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,\n                           self.groups, self.deformable_groups)\n\n\nclass ModulatedDeformConv(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,\n                 groups=1, deformable_groups=1, bias=True):\n        super(ModulatedDeformConv, self).__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.kernel_size = _pair(kernel_size)\n        self.stride = stride\n        self.padding = padding\n        self.dilation = dilation\n        self.groups = groups\n        self.deformable_groups = deformable_groups\n        self.with_bias = bias\n\n        self.weight = nn.Parameter(\n            torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))\n        if bias:\n            self.bias = nn.Parameter(torch.Tensor(out_channels))\n        else:\n            self.register_parameter('bias', None)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        n = self.in_channels\n        for k in self.kernel_size:\n            n 
*= k\n        stdv = 1. / math.sqrt(n)\n        self.weight.data.uniform_(-stdv, stdv)\n        if self.bias is not None:\n            self.bias.data.zero_()\n\n    def forward(self, x, offset, mask):\n        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,\n                                     self.padding, self.dilation, self.groups,\n                                     self.deformable_groups)\n\n\nclass ModulatedDeformConvPack(ModulatedDeformConv):\n    def __init__(self, *args, extra_offset_mask=False, **kwargs):\n        super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)\n\n        self.extra_offset_mask = extra_offset_mask\n        self.conv_offset_mask = nn.Conv2d(\n            self.in_channels,\n            self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],\n            kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),\n            bias=True)\n        self.init_offset()\n\n    def init_offset(self):\n        self.conv_offset_mask.weight.data.zero_()\n        self.conv_offset_mask.bias.data.zero_()\n\n    def forward(self, x):\n        if self.extra_offset_mask:\n            # x = [input, features]\n            out = self.conv_offset_mask(x[1])\n            x = x[0]\n        else:\n            out = self.conv_offset_mask(x)\n        o1, o2, mask = torch.chunk(out, 3, dim=1)\n        offset = torch.cat((o1, o2), dim=1)\n        mask = torch.sigmoid(mask)\n\n        offset_mean = torch.mean(torch.abs(offset))\n        if offset_mean > 100:\n            logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean))\n\n        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,\n                                     self.padding, self.dilation, self.groups,\n                                     self.deformable_groups)\n"
  },
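  {
    "path": "examples/dcn_offset_channels_sketch.py",
    "content": "# Illustrative sketch only -- NOT part of the original repository.\n# Offset/mask bookkeeping used by the packed wrappers in\n# src/model/dcn/deform_conv.py: every sampling location of a k x k kernel\n# needs one (dy, dx) offset pair per deformable group, and the modulated\n# variant adds one sigmoid-gated mask scalar per location.\ndef pack_channels(kernel_size=3, deformable_groups=1):\n    locations = kernel_size * kernel_size\n    offset_ch = deformable_groups * 2 * locations       # DeformConvPack.conv_offset\n    offset_mask_ch = deformable_groups * 3 * locations  # ModulatedDeformConvPack.conv_offset_mask\n    return offset_ch, offset_mask_ch\n\nif __name__ == '__main__':\n    print(pack_channels())  # (18, 27) for the default 3x3 kernel\n"
  },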
  {
    "path": "src/model/dcn/setup.py",
    "content": "from setuptools import setup\nfrom torch.utils.cpp_extension import BuildExtension, CUDAExtension\n\n\ndef make_cuda_ext(name, sources):\n\n    return CUDAExtension(\n        name='{}'.format(name), sources=[p for p in sources], extra_compile_args={\n            'cxx': [],\n            'nvcc': [\n                '-D__CUDA_NO_HALF_OPERATORS__',\n                '-D__CUDA_NO_HALF_CONVERSIONS__',\n                '-D__CUDA_NO_HALF2_OPERATORS__',\n            ]\n        })\n\n\nsetup(\n    name='deform_conv', ext_modules=[\n        make_cuda_ext(name='deform_conv_cuda',\n                      sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu'])\n    ], cmdclass={'build_ext': BuildExtension}, zip_safe=False)\n"
  },
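  {
    "path": "examples/dcn_build_note.py",
    "content": "# Illustrative note only -- NOT part of the original repository.\n# src/model/dcn/deform_conv.py does `from . import deform_conv_cuda`, so the\n# CUDA extension defined in src/model/dcn/setup.py has to be compiled first.\n# A standard setuptools invocation (an assumption; the sources ship no build\n# instructions) is an in-place build run from src/model/dcn:\n#\n#     python setup.py build_ext --inplace\n#\n# which drops deform_conv_cuda*.so next to deform_conv.py.\n"
  },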
  {
    "path": "src/model/dcn/src/deform_conv_cuda.cpp",
    "content": "// modify from\n// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c\n\n#include <torch/extension.h>\n\n#include <cmath>\n#include <vector>\n\nvoid deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,\n                       const int channels, const int height, const int width,\n                       const int ksize_h, const int ksize_w, const int pad_h,\n                       const int pad_w, const int stride_h, const int stride_w,\n                       const int dilation_h, const int dilation_w,\n                       const int parallel_imgs, const int deformable_group,\n                       at::Tensor data_col);\n\nvoid deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,\n                       const int channels, const int height, const int width,\n                       const int ksize_h, const int ksize_w, const int pad_h,\n                       const int pad_w, const int stride_h, const int stride_w,\n                       const int dilation_h, const int dilation_w,\n                       const int parallel_imgs, const int deformable_group,\n                       at::Tensor grad_im);\n\nvoid deformable_col2im_coord(\n    const at::Tensor data_col, const at::Tensor data_im,\n    const at::Tensor data_offset, const int channels, const int height,\n    const int width, const int ksize_h, const int ksize_w, const int pad_h,\n    const int pad_w, const int stride_h, const int stride_w,\n    const int dilation_h, const int dilation_w, const int parallel_imgs,\n    const int deformable_group, at::Tensor grad_offset);\n\nvoid modulated_deformable_im2col_cuda(\n    const at::Tensor data_im, const at::Tensor data_offset,\n    const at::Tensor data_mask, const int batch_size, const int channels,\n    const int height_im, const int width_im, const int height_col,\n    const int width_col, const int kernel_h, const int kenerl_w,\n    const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n    const int dilation_h, const int dilation_w, const int deformable_group,\n    at::Tensor data_col);\n\nvoid modulated_deformable_col2im_cuda(\n    const at::Tensor data_col, const at::Tensor data_offset,\n    const at::Tensor data_mask, const int batch_size, const int channels,\n    const int height_im, const int width_im, const int height_col,\n    const int width_col, const int kernel_h, const int kenerl_w,\n    const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n    const int dilation_h, const int dilation_w, const int deformable_group,\n    at::Tensor grad_im);\n\nvoid modulated_deformable_col2im_coord_cuda(\n    const at::Tensor data_col, const at::Tensor data_im,\n    const at::Tensor data_offset, const at::Tensor data_mask,\n    const int batch_size, const int channels, const int height_im,\n    const int width_im, const int height_col, const int width_col,\n    const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,\n    const int stride_h, const int stride_w, const int dilation_h,\n    const int dilation_w, const int deformable_group, at::Tensor grad_offset,\n    at::Tensor grad_mask);\n\nvoid shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,\n                 at::Tensor weight, int kH, int kW, int dH, int dW, int padH,\n                 int padW, int dilationH, int dilationW, int group,\n                 int deformable_group) {\n  AT_CHECK(weight.ndimension() == 4,\n           
\"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, \"\n           \"but got: %s\",\n           weight.ndimension());\n\n  AT_CHECK(weight.is_contiguous(), \"weight tensor has to be contiguous\");\n\n  AT_CHECK(kW > 0 && kH > 0,\n           \"kernel size should be greater than zero, but got kH: %d kW: %d\", kH,\n           kW);\n\n  AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),\n           \"kernel size should be consistent with weight, \",\n           \"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d\", kH,\n           kW, weight.size(2), weight.size(3));\n\n  AT_CHECK(dW > 0 && dH > 0,\n           \"stride should be greater than zero, but got dH: %d dW: %d\", dH, dW);\n\n  AT_CHECK(\n      dilationW > 0 && dilationH > 0,\n      \"dilation should be greater than 0, but got dilationH: %d dilationW: %d\",\n      dilationH, dilationW);\n\n  int ndim = input.ndimension();\n  int dimf = 0;\n  int dimh = 1;\n  int dimw = 2;\n\n  if (ndim == 4) {\n    dimf++;\n    dimh++;\n    dimw++;\n  }\n\n  AT_CHECK(ndim == 3 || ndim == 4, \"3D or 4D input tensor expected but got: %s\",\n           ndim);\n\n  long nInputPlane = weight.size(1) * group;\n  long inputHeight = input.size(dimh);\n  long inputWidth = input.size(dimw);\n  long nOutputPlane = weight.size(0);\n  long outputHeight =\n      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;\n  long outputWidth =\n      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;\n\n  AT_CHECK(nInputPlane % deformable_group == 0,\n           \"input channels must divide deformable group size\");\n\n  if (outputWidth < 1 || outputHeight < 1)\n    AT_ERROR(\n        \"Given input size: (%ld x %ld x %ld). \"\n        \"Calculated output size: (%ld x %ld x %ld). Output size is too small\",\n        nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,\n        outputWidth);\n\n  AT_CHECK(input.size(1) == nInputPlane,\n           \"invalid number of input planes, expected: %d, but got: %d\",\n           nInputPlane, input.size(1));\n\n  AT_CHECK((inputHeight >= kH && inputWidth >= kW),\n           \"input image is smaller than kernel\");\n\n  AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),\n           \"invalid spatial size of offset, expected height: %d width: %d, but \"\n           \"got height: %d width: %d\",\n           outputHeight, outputWidth, offset.size(2), offset.size(3));\n\n  AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),\n           \"invalid number of channels of offset\");\n\n  if (gradOutput != NULL) {\n    AT_CHECK(gradOutput->size(dimf) == nOutputPlane,\n             \"invalid number of gradOutput planes, expected: %d, but got: %d\",\n             nOutputPlane, gradOutput->size(dimf));\n\n    AT_CHECK((gradOutput->size(dimh) == outputHeight &&\n              gradOutput->size(dimw) == outputWidth),\n             \"invalid size of gradOutput, expected height: %d width: %d , but \"\n             \"got height: %d width: %d\",\n             outputHeight, outputWidth, gradOutput->size(dimh),\n             gradOutput->size(dimw));\n  }\n}\n\nint deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,\n                             at::Tensor offset, at::Tensor output,\n                             at::Tensor columns, at::Tensor ones, int kW,\n                             int kH, int dW, int dH, int padW, int padH,\n                             int dilationW, int dilationH, int group,\n                             int deformable_group, int 
im2col_step) {\n  // todo: resize columns to include im2col: done\n  // todo: add im2col_step as input\n  // todo: add new output buffer and transpose it to output (or directly\n  // transpose output) todo: possibly change data indexing because of\n  // parallel_imgs\n\n  shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,\n              dilationH, dilationW, group, deformable_group);\n\n  input = input.contiguous();\n  offset = offset.contiguous();\n  weight = weight.contiguous();\n\n  int batch = 1;\n  if (input.ndimension() == 3) {\n    // Force batch\n    batch = 0;\n    input.unsqueeze_(0);\n    offset.unsqueeze_(0);\n  }\n\n  // todo: assert batchsize dividable by im2col_step\n\n  long batchSize = input.size(0);\n  long nInputPlane = input.size(1);\n  long inputHeight = input.size(2);\n  long inputWidth = input.size(3);\n\n  long nOutputPlane = weight.size(0);\n\n  long outputWidth =\n      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;\n  long outputHeight =\n      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;\n\n  AT_CHECK((offset.size(0) == batchSize), \"invalid batch size of offset\");\n\n  output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,\n                        outputHeight, outputWidth});\n  columns = at::zeros(\n      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},\n      input.options());\n\n  if (ones.ndimension() != 2 ||\n      ones.size(0) * ones.size(1) < outputHeight * outputWidth) {\n    ones = at::ones({outputHeight, outputWidth}, input.options());\n  }\n\n  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,\n                      inputHeight, inputWidth});\n  offset =\n      offset.view({batchSize / im2col_step, im2col_step,\n                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});\n\n  at::Tensor output_buffer =\n      at::zeros({batchSize / im2col_step, nOutputPlane,\n                 im2col_step * outputHeight, outputWidth},\n                output.options());\n\n  output_buffer = output_buffer.view(\n      {output_buffer.size(0), group, output_buffer.size(1) / group,\n       output_buffer.size(2), output_buffer.size(3)});\n\n  for (int elt = 0; elt < batchSize / im2col_step; elt++) {\n    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,\n                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,\n                      dilationW, im2col_step, deformable_group, columns);\n\n    columns = columns.view({group, columns.size(0) / group, columns.size(1)});\n    weight = weight.view({group, weight.size(0) / group, weight.size(1),\n                          weight.size(2), weight.size(3)});\n\n    for (int g = 0; g < group; g++) {\n      output_buffer[elt][g] = output_buffer[elt][g]\n                                  .flatten(1)\n                                  .addmm_(weight[g].flatten(1), columns[g])\n                                  .view_as(output_buffer[elt][g]);\n    }\n  }\n\n  output_buffer = output_buffer.view(\n      {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),\n       output_buffer.size(3), output_buffer.size(4)});\n\n  output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,\n                                      im2col_step, outputHeight, outputWidth});\n  output_buffer.transpose_(1, 2);\n  output.copy_(output_buffer);\n  output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});\n\n  input = input.view({batchSize, nInputPlane, 
inputHeight, inputWidth});\n  offset = offset.view(\n      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});\n\n  if (batch == 0) {\n    output = output.view({nOutputPlane, outputHeight, outputWidth});\n    input = input.view({nInputPlane, inputHeight, inputWidth});\n    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});\n  }\n\n  return 1;\n}\n\nint deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,\n                                    at::Tensor gradOutput, at::Tensor gradInput,\n                                    at::Tensor gradOffset, at::Tensor weight,\n                                    at::Tensor columns, int kW, int kH, int dW,\n                                    int dH, int padW, int padH, int dilationW,\n                                    int dilationH, int group,\n                                    int deformable_group, int im2col_step) {\n  shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,\n              dilationH, dilationW, group, deformable_group);\n\n  input = input.contiguous();\n  offset = offset.contiguous();\n  gradOutput = gradOutput.contiguous();\n  weight = weight.contiguous();\n\n  int batch = 1;\n\n  if (input.ndimension() == 3) {\n    // Force batch\n    batch = 0;\n    input = input.view({1, input.size(0), input.size(1), input.size(2)});\n    offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});\n    gradOutput = gradOutput.view(\n        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});\n  }\n\n  long batchSize = input.size(0);\n  long nInputPlane = input.size(1);\n  long inputHeight = input.size(2);\n  long inputWidth = input.size(3);\n\n  long nOutputPlane = weight.size(0);\n\n  long outputWidth =\n      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;\n  long outputHeight =\n      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;\n\n  AT_CHECK((offset.size(0) == batchSize), \"invalid batch size of offset\");\n  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});\n  columns = at::zeros(\n      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},\n      input.options());\n\n  // change order of grad output\n  gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,\n                                nOutputPlane, outputHeight, outputWidth});\n  gradOutput.transpose_(1, 2);\n\n  gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,\n                              inputHeight, inputWidth});\n  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,\n                      inputHeight, inputWidth});\n  gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,\n                                deformable_group * 2 * kH * kW, outputHeight,\n                                outputWidth});\n  offset =\n      offset.view({batchSize / im2col_step, im2col_step,\n                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});\n\n  for (int elt = 0; elt < batchSize / im2col_step; elt++) {\n    // divide into groups\n    columns = columns.view({group, columns.size(0) / group, columns.size(1)});\n    weight = weight.view({group, weight.size(0) / group, weight.size(1),\n                          weight.size(2), weight.size(3)});\n    gradOutput = gradOutput.view(\n        {gradOutput.size(0), group, gradOutput.size(1) / group,\n         gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});\n\n    for 
(int g = 0; g < group; g++) {\n      columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),\n                                     gradOutput[elt][g].flatten(1), 0.0f, 1.0f);\n    }\n\n    columns =\n        columns.view({columns.size(0) * columns.size(1), columns.size(2)});\n    gradOutput = gradOutput.view(\n        {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),\n         gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});\n\n    deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,\n                            inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,\n                            dilationH, dilationW, im2col_step, deformable_group,\n                            gradOffset[elt]);\n\n    deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,\n                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,\n                      dilationW, im2col_step, deformable_group, gradInput[elt]);\n  }\n\n  gradOutput.transpose_(1, 2);\n  gradOutput =\n      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});\n\n  gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});\n  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});\n  gradOffset = gradOffset.view(\n      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});\n  offset = offset.view(\n      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});\n\n  if (batch == 0) {\n    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});\n    input = input.view({nInputPlane, inputHeight, inputWidth});\n    gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});\n    offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});\n    gradOffset =\n        gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});\n  }\n\n  return 1;\n}\n\nint deform_conv_backward_parameters_cuda(\n    at::Tensor input, at::Tensor offset, at::Tensor gradOutput,\n    at::Tensor gradWeight,  // at::Tensor gradBias,\n    at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,\n    int padW, int padH, int dilationW, int dilationH, int group,\n    int deformable_group, float scale, int im2col_step) {\n  // todo: transpose and reshape outGrad\n  // todo: reshape columns\n  // todo: add im2col_step as input\n\n  shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,\n              padW, dilationH, dilationW, group, deformable_group);\n\n  input = input.contiguous();\n  offset = offset.contiguous();\n  gradOutput = gradOutput.contiguous();\n\n  int batch = 1;\n\n  if (input.ndimension() == 3) {\n    // Force batch\n    batch = 0;\n    input = input.view(\n        at::IntList({1, input.size(0), input.size(1), input.size(2)}));\n    gradOutput = gradOutput.view(\n        {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});\n  }\n\n  long batchSize = input.size(0);\n  long nInputPlane = input.size(1);\n  long inputHeight = input.size(2);\n  long inputWidth = input.size(3);\n\n  long nOutputPlane = gradWeight.size(0);\n\n  long outputWidth =\n      (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;\n  long outputHeight =\n      (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;\n\n  AT_CHECK((offset.size(0) == batchSize), \"invalid batch size of offset\");\n\n  columns = at::zeros(\n      {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},\n      input.options());\n\n  
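// Rearrange grad_output for the per-step GEMMs below: view to\n  // (batch/step, step, C_out, H, W) and transpose dims 1 and 2 so each\n  // im2col step's outputs are contiguous in gradOutputBuffer.\n  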
gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,\n                                nOutputPlane, outputHeight, outputWidth});\n  gradOutput.transpose_(1, 2);\n\n  at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);\n  gradOutputBuffer =\n      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,\n                             outputHeight, outputWidth});\n  gradOutputBuffer.copy_(gradOutput);\n  gradOutputBuffer =\n      gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,\n                             im2col_step * outputHeight, outputWidth});\n\n  gradOutput.transpose_(1, 2);\n  gradOutput =\n      gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});\n\n  input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,\n                      inputHeight, inputWidth});\n  offset =\n      offset.view({batchSize / im2col_step, im2col_step,\n                   deformable_group * 2 * kH * kW, outputHeight, outputWidth});\n\n  for (int elt = 0; elt < batchSize / im2col_step; elt++) {\n    deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,\n                      inputWidth, kH, kW, padH, padW, dH, dW, dilationH,\n                      dilationW, im2col_step, deformable_group, columns);\n\n    // divide into group\n    gradOutputBuffer = gradOutputBuffer.view(\n        {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,\n         gradOutputBuffer.size(2), gradOutputBuffer.size(3)});\n    columns = columns.view({group, columns.size(0) / group, columns.size(1)});\n    gradWeight =\n        gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),\n                         gradWeight.size(2), gradWeight.size(3)});\n\n    for (int g = 0; g < group; g++) {\n      gradWeight[g] = gradWeight[g]\n                          .flatten(1)\n                          .addmm_(gradOutputBuffer[elt][g].flatten(1),\n                                  columns[g].transpose(1, 0), 1.0, scale)\n                          .view_as(gradWeight[g]);\n    }\n    gradOutputBuffer = gradOutputBuffer.view(\n        {gradOutputBuffer.size(0),\n         gradOutputBuffer.size(1) * gradOutputBuffer.size(2),\n         gradOutputBuffer.size(3), gradOutputBuffer.size(4)});\n    columns =\n        columns.view({columns.size(0) * columns.size(1), columns.size(2)});\n    gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),\n                                  gradWeight.size(2), gradWeight.size(3),\n                                  gradWeight.size(4)});\n  }\n\n  input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});\n  offset = offset.view(\n      {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});\n\n  if (batch == 0) {\n    gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});\n    input = input.view({nInputPlane, inputHeight, inputWidth});\n  }\n\n  return 1;\n}\n\nvoid modulated_deform_conv_cuda_forward(\n    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,\n    at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,\n    int kernel_h, int kernel_w, const int stride_h, const int stride_w,\n    const int pad_h, const int pad_w, const int dilation_h,\n    const int dilation_w, const int group, const int deformable_group,\n    const bool with_bias) {\n  AT_CHECK(input.is_contiguous(), \"input tensor has to be contiguous\");\n  AT_CHECK(weight.is_contiguous(), \"weight tensor has to be 
contiguous\");\n\n  const int batch = input.size(0);\n  const int channels = input.size(1);\n  const int height = input.size(2);\n  const int width = input.size(3);\n\n  const int channels_out = weight.size(0);\n  const int channels_kernel = weight.size(1);\n  const int kernel_h_ = weight.size(2);\n  const int kernel_w_ = weight.size(3);\n\n  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)\n    AT_ERROR(\"Input shape and kernel shape wont match: (%d x %d vs %d x %d).\",\n             kernel_h_, kernel_w, kernel_h_, kernel_w_);\n  if (channels != channels_kernel * group)\n    AT_ERROR(\"Input shape and kernel channels wont match: (%d vs %d).\",\n             channels, channels_kernel * group);\n\n  const int height_out =\n      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;\n  const int width_out =\n      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;\n\n  if (ones.ndimension() != 2 ||\n      ones.size(0) * ones.size(1) < height_out * width_out) {\n    // Resize plane and fill with ones...\n    ones = at::ones({height_out, width_out}, input.options());\n  }\n\n  // resize output\n  output = output.view({batch, channels_out, height_out, width_out}).zero_();\n  // resize temporary columns\n  columns =\n      at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},\n                input.options());\n\n  output = output.view({output.size(0), group, output.size(1) / group,\n                        output.size(2), output.size(3)});\n\n  for (int b = 0; b < batch; b++) {\n    modulated_deformable_im2col_cuda(\n        input[b], offset[b], mask[b], 1, channels, height, width, height_out,\n        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,\n        dilation_h, dilation_w, deformable_group, columns);\n\n    // divide into group\n    weight = weight.view({group, weight.size(0) / group, weight.size(1),\n                          weight.size(2), weight.size(3)});\n    columns = columns.view({group, columns.size(0) / group, columns.size(1)});\n\n    for (int g = 0; g < group; g++) {\n      output[b][g] = output[b][g]\n                         .flatten(1)\n                         .addmm_(weight[g].flatten(1), columns[g])\n                         .view_as(output[b][g]);\n    }\n\n    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),\n                          weight.size(3), weight.size(4)});\n    columns =\n        columns.view({columns.size(0) * columns.size(1), columns.size(2)});\n  }\n\n  output = output.view({output.size(0), output.size(1) * output.size(2),\n                        output.size(3), output.size(4)});\n\n  if (with_bias) {\n    output += bias.view({1, bias.size(0), 1, 1});\n  }\n}\n\nvoid modulated_deform_conv_cuda_backward(\n    at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,\n    at::Tensor offset, at::Tensor mask, at::Tensor columns,\n    at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,\n    at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,\n    int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,\n    int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,\n    const bool with_bias) {\n  AT_CHECK(input.is_contiguous(), \"input tensor has to be contiguous\");\n  AT_CHECK(weight.is_contiguous(), \"weight tensor has to be contiguous\");\n\n  const int batch = input.size(0);\n  const int channels = input.size(1);\n  const int height = input.size(2);\n  const int width = 
input.size(3);\n\n  const int channels_kernel = weight.size(1);\n  const int kernel_h_ = weight.size(2);\n  const int kernel_w_ = weight.size(3);\n  if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)\n    AT_ERROR(\"Input shape and kernel shape won't match: (%d x %d vs %d x %d).\",\n             kernel_h, kernel_w, kernel_h_, kernel_w_);\n  if (channels != channels_kernel * group)\n    AT_ERROR(\"Input shape and kernel channels won't match: (%d vs %d).\",\n             channels, channels_kernel * group);\n\n  const int height_out =\n      (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;\n  const int width_out =\n      (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;\n\n  if (ones.ndimension() != 2 ||\n      ones.size(0) * ones.size(1) < height_out * width_out) {\n    // Resize plane and fill with ones...\n    ones = at::ones({height_out, width_out}, input.options());\n  }\n\n  grad_input = grad_input.view({batch, channels, height, width});\n  columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},\n                      input.options());\n\n  grad_output =\n      grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,\n                        grad_output.size(2), grad_output.size(3)});\n\n  for (int b = 0; b < batch; b++) {\n    // divide into group\n    columns = columns.view({group, columns.size(0) / group, columns.size(1)});\n    weight = weight.view({group, weight.size(0) / group, weight.size(1),\n                          weight.size(2), weight.size(3)});\n\n    for (int g = 0; g < group; g++) {\n      columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),\n                        grad_output[b][g].flatten(1), 0.0f, 1.0f);\n    }\n\n    columns =\n        columns.view({columns.size(0) * columns.size(1), columns.size(2)});\n    weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),\n                          weight.size(3), weight.size(4)});\n\n    // gradient w.r.t. input coordinate data\n    modulated_deformable_col2im_coord_cuda(\n        columns, input[b], offset[b], mask[b], 1, channels, height, width,\n        height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,\n        stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],\n        grad_mask[b]);\n    // gradient w.r.t. input data\n    modulated_deformable_col2im_cuda(\n        columns, offset[b], mask[b], 1, channels, height, width, height_out,\n        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,\n        dilation_h, dilation_w, deformable_group, grad_input[b]);\n\n    // gradient w.r.t. 
weight, dWeight should accumulate across the batch and\n    // group\n    modulated_deformable_im2col_cuda(\n        input[b], offset[b], mask[b], 1, channels, height, width, height_out,\n        width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,\n        dilation_h, dilation_w, deformable_group, columns);\n\n    columns = columns.view({group, columns.size(0) / group, columns.size(1)});\n    grad_weight = grad_weight.view({group, grad_weight.size(0) / group,\n                                    grad_weight.size(1), grad_weight.size(2),\n                                    grad_weight.size(3)});\n    if (with_bias)\n      grad_bias = grad_bias.view({group, grad_bias.size(0) / group});\n\n    for (int g = 0; g < group; g++) {\n      grad_weight[g] =\n          grad_weight[g]\n              .flatten(1)\n              .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))\n              .view_as(grad_weight[g]);\n      if (with_bias) {\n        grad_bias[g] =\n            grad_bias[g]\n                .view({-1, 1})\n                .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))\n                .view(-1);\n      }\n    }\n\n    columns =\n        columns.view({columns.size(0) * columns.size(1), columns.size(2)});\n    grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),\n                                    grad_weight.size(2), grad_weight.size(3),\n                                    grad_weight.size(4)});\n    if (with_bias)\n      grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});\n  }\n  grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),\n                                  grad_output.size(2), grad_output.size(3),\n                                  grad_output.size(4)});\n}\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"deform_conv_forward_cuda\", &deform_conv_forward_cuda,\n        \"deform forward (CUDA)\");\n  m.def(\"deform_conv_backward_input_cuda\", &deform_conv_backward_input_cuda,\n        \"deform_conv_backward_input (CUDA)\");\n  m.def(\"deform_conv_backward_parameters_cuda\",\n        &deform_conv_backward_parameters_cuda,\n        \"deform_conv_backward_parameters (CUDA)\");\n  m.def(\"modulated_deform_conv_cuda_forward\",\n        &modulated_deform_conv_cuda_forward,\n        \"modulated deform conv forward (CUDA)\");\n  m.def(\"modulated_deform_conv_cuda_backward\",\n        &modulated_deform_conv_cuda_backward,\n        \"modulated deform conv backward (CUDA)\");\n}\n"
  },
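  {
    "path": "examples/conv_output_size_sketch.py",
    "content": "# Illustrative sketch only -- NOT part of the original repository.\n# The output-size formula shared by shape_check in deform_conv_cuda.cpp and\n# DeformConvFunction._output_size in deform_conv.py:\n#     out = (in + 2*pad - (dilation*(kernel - 1) + 1)) // stride + 1\ndef conv_out(in_size, kernel, stride=1, pad=0, dilation=1):\n    return (in_size + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1\n\nif __name__ == '__main__':\n    print(conv_out(64, 3, stride=1, pad=1))  # 64: 3x3 'same' convolution\n    print(conv_out(64, 3, stride=2, pad=1))  # 32: stride-2 downsampling\n"
  },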
  {
    "path": "src/model/dcn/src/deform_conv_cuda_kernel.cu",
    "content": "/*!\n ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************\n *\n * COPYRIGHT\n *\n * All contributions by the University of California:\n * Copyright (c) 2014-2017 The Regents of the University of California (Regents)\n * All rights reserved.\n *\n * All other contributions:\n * Copyright (c) 2014-2017, the respective contributors\n * All rights reserved.\n *\n * Caffe uses a shared copyright model: each contributor holds copyright over\n * their contributions to Caffe. The project versioning records all such\n * contribution and copyright details. If a contributor wants to further mark\n * their specific copyright on a particular contribution, they should indicate\n * their copyright solely in the commit message of the change when it is\n * committed.\n *\n * LICENSE\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions are met:\n *\n * 1. Redistributions of source code must retain the above copyright notice, this\n * list of conditions and the following disclaimer.\n * 2. Redistributions in binary form must reproduce the above copyright notice,\n * this list of conditions and the following disclaimer in the documentation\n * and/or other materials provided with the distribution.\n *\n * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\n * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR\n * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND\n * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\n * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n *\n * CONTRIBUTION AGREEMENT\n *\n * By contributing to the BVLC/caffe repository through pull-request, comment,\n * or otherwise, the contributor releases their content to the\n * license and copyright terms herein.\n *\n ***************** END Caffe Copyright Notice and Disclaimer ********************\n *\n * Copyright (c) 2018 Microsoft\n * Licensed under The MIT License [see LICENSE for details]\n * \\file modulated_deformable_im2col.cuh\n * \\brief Function definitions of converting an image to\n * column matrix based on kernel, padding, dilation, and offset.\n * These functions are mainly used in deformable convolution operators.\n * \\ref: https://arxiv.org/abs/1703.06211\n * \\author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng\n */\n\n// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu\n\n#include <ATen/ATen.h>\n#include <THC/THCAtomics.cuh>\n#include <stdio.h>\n#include <math.h>\n#include <float.h>\n\nusing namespace at;\n\n#define CUDA_KERNEL_LOOP(i, n)                                 \\\n  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \\\n       i += blockDim.x * gridDim.x)\n\nconst int CUDA_NUM_THREADS = 1024;\nconst int kMaxGridNum = 65535;\n\ninline int GET_BLOCKS(const int N)\n{\n  return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / 
CUDA_NUM_THREADS);\n}\n\ntemplate <typename scalar_t>\n__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,\n                                               const int height, const int width, scalar_t h, scalar_t w)\n{\n\n  int h_low = floor(h);\n  int w_low = floor(w);\n  int h_high = h_low + 1;\n  int w_high = w_low + 1;\n\n  scalar_t lh = h - h_low;\n  scalar_t lw = w - w_low;\n  scalar_t hh = 1 - lh, hw = 1 - lw;\n\n  scalar_t v1 = 0;\n  if (h_low >= 0 && w_low >= 0)\n    v1 = bottom_data[h_low * data_width + w_low];\n  scalar_t v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1)\n    v2 = bottom_data[h_low * data_width + w_high];\n  scalar_t v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0)\n    v3 = bottom_data[h_high * data_width + w_low];\n  scalar_t v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1)\n    v4 = bottom_data[h_high * data_width + w_high];\n\n  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n\n  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  return val;\n}\n\ntemplate <typename scalar_t>\n__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,\n                                        const int h, const int w, const int height, const int width)\n{\n\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)\n  {\n    //empty\n    return 0;\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = floor(argmax_w);\n  int argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  scalar_t weight = 0;\n  if (h == argmax_h_low && w == argmax_w_low)\n    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);\n  if (h == argmax_h_low && w == argmax_w_high)\n    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);\n  if (h == argmax_h_high && w == argmax_w_low)\n    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);\n  if (h == argmax_h_high && w == argmax_w_high)\n    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);\n  return weight;\n}\n\ntemplate <typename scalar_t>\n__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,\n                                          const int height, const int width, const scalar_t *im_data,\n                                          const int data_width, const int bp_dir)\n{\n\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)\n  {\n    //empty\n    return 0;\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = floor(argmax_w);\n  int argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  scalar_t weight = 0;\n\n  if (bp_dir == 0)\n  {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)\n      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  }\n  else if (bp_dir == 1)\n  {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && argmax_w_high <= 
width - 1)\n      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  }\n\n  return weight;\n}\n\ntemplate <typename scalar_t>\n__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,\n                                             const int height, const int width, const int kernel_h, const int kernel_w,\n                                             const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n                                             const int dilation_h, const int dilation_w, const int channel_per_deformable_group,\n                                             const int batch_size, const int num_channels, const int deformable_group,\n                                             const int height_col, const int width_col,\n                                             scalar_t *data_col)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    // index index of output matrix\n    const int w_col = index % width_col;\n    const int h_col = (index / width_col) % height_col;\n    const int b_col = (index / width_col / height_col) % batch_size;\n    const int c_im = (index / width_col / height_col) / batch_size;\n    const int c_col = c_im * kernel_h * kernel_w;\n\n    // compute deformable group index\n    const int deformable_group_index = c_im / channel_per_deformable_group;\n\n    const int h_in = h_col * stride_h - pad_h;\n    const int w_in = w_col * stride_w - pad_w;\n    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;\n    //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;\n    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;\n    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n\n    for (int i = 0; i < kernel_h; ++i)\n    {\n      for (int j = 0; j < kernel_w; ++j)\n      {\n        const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;\n        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;\n        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];\n        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];\n        scalar_t val = static_cast<scalar_t>(0);\n        const scalar_t h_im = h_in + i * dilation_h + offset_h;\n        const scalar_t w_im = w_in + j * dilation_w + offset_w;\n        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)\n        {\n          //const scalar_t map_h = i * dilation_h + offset_h;\n          //const scalar_t map_w = j * dilation_w + offset_w;\n          //const int cur_height = height - h_in;\n          //const int cur_width = width - w_in;\n          //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);\n          val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);\n        }\n        *data_col_ptr = val;\n        data_col_ptr += batch_size * height_col * 
width_col;\n      }\n    }\n  }\n}\n\nvoid deformable_im2col(\n    const at::Tensor data_im, const at::Tensor data_offset, const int channels,\n    const int height, const int width, const int ksize_h, const int ksize_w,\n    const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n    const int dilation_h, const int dilation_w, const int parallel_imgs,\n    const int deformable_group, at::Tensor data_col)\n{\n  // num_axes should be smaller than block size\n  // todo: check parallel_imgs is correctly passed in\n  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;\n  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;\n  int num_kernels = channels * height_col * width_col * parallel_imgs;\n  int channel_per_deformable_group = channels / deformable_group;\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n      data_im.scalar_type(), \"deformable_im2col_gpu\", ([&] {\n        const scalar_t *data_im_ = data_im.data<scalar_t>();\n        const scalar_t *data_offset_ = data_offset.data<scalar_t>();\n        scalar_t *data_col_ = data_col.data<scalar_t>();\n\n        deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(\n            num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,\n            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,\n            channel_per_deformable_group, parallel_imgs, channels, deformable_group,\n            height_col, width_col, data_col_);\n      }));\n\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in deformable_im2col: %s\\n\", cudaGetErrorString(err));\n  }\n}\n\ntemplate <typename scalar_t>\n__global__ void deformable_col2im_gpu_kernel(\n    const int n, const scalar_t *data_col, const scalar_t *data_offset,\n    const int channels, const int height, const int width,\n    const int kernel_h, const int kernel_w,\n    const int pad_h, const int pad_w,\n    const int stride_h, const int stride_w,\n    const int dilation_h, const int dilation_w,\n    const int channel_per_deformable_group,\n    const int batch_size, const int deformable_group,\n    const int height_col, const int width_col,\n    scalar_t *grad_im)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    const int j = (index / width_col / height_col / batch_size) % kernel_w;\n    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;\n    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;\n    // compute the start and end of the output\n\n    const int deformable_group_index = c / channel_per_deformable_group;\n\n    int w_out = index % width_col;\n    int h_out = (index / width_col) % height_col;\n    int b = (index / width_col / height_col) % batch_size;\n    int w_in = w_out * stride_w - pad_w;\n    int h_in = h_out * stride_h - pad_h;\n\n    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *\n                                                        2 * kernel_h * kernel_w * height_col * width_col;\n    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;\n    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;\n    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];\n    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];\n    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;\n    const 
scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;\n\n    const scalar_t cur_top_grad = data_col[index];\n    const int cur_h = (int)cur_inv_h_data;\n    const int cur_w = (int)cur_inv_w_data;\n    for (int dy = -2; dy <= 2; dy++)\n    {\n      for (int dx = -2; dx <= 2; dx++)\n      {\n        if (cur_h + dy >= 0 && cur_h + dy < height &&\n            cur_w + dx >= 0 && cur_w + dx < width &&\n            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&\n            abs(cur_inv_w_data - (cur_w + dx)) < 1)\n        {\n          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;\n          scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);\n          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);\n        }\n      }\n    }\n  }\n}\n\nvoid deformable_col2im(\n    const at::Tensor data_col, const at::Tensor data_offset, const int channels,\n    const int height, const int width, const int ksize_h,\n    const int ksize_w, const int pad_h, const int pad_w,\n    const int stride_h, const int stride_w,\n    const int dilation_h, const int dilation_w,\n    const int parallel_imgs, const int deformable_group,\n    at::Tensor grad_im)\n{\n\n  // todo: make sure parallel_imgs is passed in correctly\n  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;\n  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;\n  int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;\n  int channel_per_deformable_group = channels / deformable_group;\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n      data_col.scalar_type(), \"deformable_col2im_gpu\", ([&] {\n        const scalar_t *data_col_ = data_col.data<scalar_t>();\n        const scalar_t *data_offset_ = data_offset.data<scalar_t>();\n        scalar_t *grad_im_ = grad_im.data<scalar_t>();\n\n        deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(\n            num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,\n            ksize_w, pad_h, pad_w, stride_h, stride_w,\n            dilation_h, dilation_w, channel_per_deformable_group,\n            parallel_imgs, deformable_group, height_col, width_col, grad_im_);\n      }));\n\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in deformable_col2im: %s\\n\", cudaGetErrorString(err));\n  }\n}\n\ntemplate <typename scalar_t>\n__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,\n                                                   const scalar_t *data_im, const scalar_t *data_offset,\n                                                   const int channels, const int height, const int width,\n                                                   const int kernel_h, const int kernel_w,\n                                                   const int pad_h, const int pad_w,\n                                                   const int stride_h, const int stride_w,\n                                                   const int dilation_h, const int dilation_w,\n                                                   const int channel_per_deformable_group,\n                                                   const int batch_size, const int offset_channels, const int deformable_group,\n                                                   const int height_col, const int width_col, scalar_t 
*grad_offset)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    scalar_t val = 0;\n    int w = index % width_col;\n    int h = (index / width_col) % height_col;\n    int c = (index / width_col / height_col) % offset_channels;\n    int b = (index / width_col / height_col) / offset_channels;\n    // compute the start and end of the output\n\n    const int deformable_group_index = c / (2 * kernel_h * kernel_w);\n    const int col_step = kernel_h * kernel_w;\n    int cnt = 0;\n    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *\n                                                  batch_size * width_col * height_col;\n    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *\n                                                channel_per_deformable_group / kernel_h / kernel_w * height * width;\n    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *\n                                                        kernel_h * kernel_w * height_col * width_col;\n\n    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;\n\n    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)\n    {\n      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;\n      const int bp_dir = offset_c % 2;\n\n      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;\n      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;\n      int w_out = col_pos % width_col;\n      int h_out = (col_pos / width_col) % height_col;\n      int w_in = w_out * stride_w - pad_w;\n      int h_in = h_out * stride_h - pad_h;\n      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);\n      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);\n      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];\n      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];\n      scalar_t inv_h = h_in + i * dilation_h + offset_h;\n      scalar_t inv_w = w_in + j * dilation_w + offset_w;\n      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)\n      {\n        inv_h = inv_w = -2;\n      }\n      const scalar_t weight = get_coordinate_weight(\n          inv_h, inv_w,\n          height, width, data_im_ptr + cnt * height * width, width, bp_dir);\n      val += weight * data_col_ptr[col_pos];\n      cnt += 1;\n    }\n\n    grad_offset[index] = val;\n  }\n}\n\nvoid deformable_col2im_coord(\n    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,\n    const int channels, const int height, const int width, const int ksize_h,\n    const int ksize_w, const int pad_h, const int pad_w, const int stride_h,\n    const int stride_w, const int dilation_h, const int dilation_w,\n    const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)\n{\n\n  int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;\n  int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;\n  int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;\n  int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n      data_col.scalar_type(), \"deformable_col2im_coord_gpu\", ([&] {\n        const scalar_t 
*data_col_ = data_col.data<scalar_t>();\n        const scalar_t *data_im_ = data_im.data<scalar_t>();\n        const scalar_t *data_offset_ = data_offset.data<scalar_t>();\n        scalar_t *grad_offset_ = grad_offset.data<scalar_t>();\n\n        deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(\n            num_kernels, data_col_, data_im_, data_offset_, channels, height, width,\n            ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,\n            dilation_h, dilation_w, channel_per_deformable_group,\n            parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,\n            height_col, width_col, grad_offset_);\n      }));\n}\n\ntemplate <typename scalar_t>\n__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,\n                                         const int height, const int width, scalar_t h, scalar_t w)\n{\n  int h_low = floor(h);\n  int w_low = floor(w);\n  int h_high = h_low + 1;\n  int w_high = w_low + 1;\n\n  scalar_t lh = h - h_low;\n  scalar_t lw = w - w_low;\n  scalar_t hh = 1 - lh, hw = 1 - lw;\n\n  scalar_t v1 = 0;\n  if (h_low >= 0 && w_low >= 0)\n    v1 = bottom_data[h_low * data_width + w_low];\n  scalar_t v2 = 0;\n  if (h_low >= 0 && w_high <= width - 1)\n    v2 = bottom_data[h_low * data_width + w_high];\n  scalar_t v3 = 0;\n  if (h_high <= height - 1 && w_low >= 0)\n    v3 = bottom_data[h_high * data_width + w_low];\n  scalar_t v4 = 0;\n  if (h_high <= height - 1 && w_high <= width - 1)\n    v4 = bottom_data[h_high * data_width + w_high];\n\n  scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;\n\n  scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);\n  return val;\n}\n\ntemplate <typename scalar_t>\n__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,\n                                             const int h, const int w, const int height, const int width)\n{\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)\n  {\n    //empty\n    return 0;\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = floor(argmax_w);\n  int argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  scalar_t weight = 0;\n  if (h == argmax_h_low && w == argmax_w_low)\n    weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);\n  if (h == argmax_h_low && w == argmax_w_high)\n    weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);\n  if (h == argmax_h_high && w == argmax_w_low)\n    weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);\n  if (h == argmax_h_high && w == argmax_w_high)\n    weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);\n  return weight;\n}\n\ntemplate <typename scalar_t>\n__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,\n                                               const int height, const int width, const scalar_t *im_data,\n                                               const int data_width, const int bp_dir)\n{\n  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)\n  {\n    //empty\n    return 0;\n  }\n\n  int argmax_h_low = floor(argmax_h);\n  int argmax_w_low = floor(argmax_w);\n  int argmax_h_high = argmax_h_low + 1;\n  int argmax_w_high = argmax_w_low + 1;\n\n  scalar_t weight = 0;\n\n  if (bp_dir == 0)\n  {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && 
argmax_w_high <= width - 1)\n      weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  }\n  else if (bp_dir == 1)\n  {\n    if (argmax_h_low >= 0 && argmax_w_low >= 0)\n      weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];\n    if (argmax_h_low >= 0 && argmax_w_high <= width - 1)\n      weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];\n    if (argmax_h_high <= height - 1 && argmax_w_low >= 0)\n      weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];\n    if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)\n      weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];\n  }\n\n  return weight;\n}\n\ntemplate <typename scalar_t>\n__global__ void modulated_deformable_im2col_gpu_kernel(const int n,\n                                                       const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,\n                                                       const int height, const int width, const int kernel_h, const int kernel_w,\n                                                       const int pad_h, const int pad_w,\n                                                       const int stride_h, const int stride_w,\n                                                       const int dilation_h, const int dilation_w,\n                                                       const int channel_per_deformable_group,\n                                                       const int batch_size, const int num_channels, const int deformable_group,\n                                                       const int height_col, const int width_col,\n                                                       scalar_t *data_col)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    // index index of output matrix\n    const int w_col = index % width_col;\n    const int h_col = (index / width_col) % height_col;\n    const int b_col = (index / width_col / height_col) % batch_size;\n    const int c_im = (index / width_col / height_col) / batch_size;\n    const int c_col = c_im * kernel_h * kernel_w;\n\n    // compute deformable group index\n    const int deformable_group_index = c_im / channel_per_deformable_group;\n\n    const int h_in = h_col * stride_h - pad_h;\n    const int w_in = w_col * stride_w - pad_w;\n\n    scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;\n    //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;\n    const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;\n    const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n\n    const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n\n    for (int i = 0; i < kernel_h; ++i)\n    {\n      for (int j = 0; j < kernel_w; ++j)\n      {\n        const int 
data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;\n        const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;\n        const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;\n        const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];\n        const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];\n        const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];\n        scalar_t val = static_cast<scalar_t>(0);\n        const scalar_t h_im = h_in + i * dilation_h + offset_h;\n        const scalar_t w_im = w_in + j * dilation_w + offset_w;\n        //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {\n        if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)\n        {\n          //const float map_h = i * dilation_h + offset_h;\n          //const float map_w = j * dilation_w + offset_w;\n          //const int cur_height = height - h_in;\n          //const int cur_width = width - w_in;\n          //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);\n          val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);\n        }\n        *data_col_ptr = val * mask;\n        data_col_ptr += batch_size * height_col * width_col;\n        //data_col_ptr += height_col * width_col;\n      }\n    }\n  }\n}\n\ntemplate <typename scalar_t>\n__global__ void modulated_deformable_col2im_gpu_kernel(const int n,\n                                                       const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,\n                                                       const int channels, const int height, const int width,\n                                                       const int kernel_h, const int kernel_w,\n                                                       const int pad_h, const int pad_w,\n                                                       const int stride_h, const int stride_w,\n                                                       const int dilation_h, const int dilation_w,\n                                                       const int channel_per_deformable_group,\n                                                       const int batch_size, const int deformable_group,\n                                                       const int height_col, const int width_col,\n                                                       scalar_t *grad_im)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    const int j = (index / width_col / height_col / batch_size) % kernel_w;\n    const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;\n    const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;\n    // compute the start and end of the output\n\n    const int deformable_group_index = c / channel_per_deformable_group;\n\n    int w_out = index % width_col;\n    int h_out = (index / width_col) % height_col;\n    int b = (index / width_col / height_col) % batch_size;\n    int w_in = w_out * stride_w - pad_w;\n    int h_in = h_out * stride_h - pad_h;\n\n    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n    const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n    const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * 
height_col + h_out) * width_col + w_out;\n    const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;\n    const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;\n    const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];\n    const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];\n    const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];\n    const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;\n    const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;\n\n    const scalar_t cur_top_grad = data_col[index] * mask;\n    const int cur_h = (int)cur_inv_h_data;\n    const int cur_w = (int)cur_inv_w_data;\n    for (int dy = -2; dy <= 2; dy++)\n    {\n      for (int dx = -2; dx <= 2; dx++)\n      {\n        if (cur_h + dy >= 0 && cur_h + dy < height &&\n            cur_w + dx >= 0 && cur_w + dx < width &&\n            abs(cur_inv_h_data - (cur_h + dy)) < 1 &&\n            abs(cur_inv_w_data - (cur_w + dx)) < 1)\n        {\n          int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;\n          scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);\n          atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);\n        }\n      }\n    }\n  }\n}\n\ntemplate <typename scalar_t>\n__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,\n                                                             const scalar_t *data_col, const scalar_t *data_im,\n                                                             const scalar_t *data_offset, const scalar_t *data_mask,\n                                                             const int channels, const int height, const int width,\n                                                             const int kernel_h, const int kernel_w,\n                                                             const int pad_h, const int pad_w,\n                                                             const int stride_h, const int stride_w,\n                                                             const int dilation_h, const int dilation_w,\n                                                             const int channel_per_deformable_group,\n                                                             const int batch_size, const int offset_channels, const int deformable_group,\n                                                             const int height_col, const int width_col,\n                                                             scalar_t *grad_offset, scalar_t *grad_mask)\n{\n  CUDA_KERNEL_LOOP(index, n)\n  {\n    scalar_t val = 0, mval = 0;\n    int w = index % width_col;\n    int h = (index / width_col) % height_col;\n    int c = (index / width_col / height_col) % offset_channels;\n    int b = (index / width_col / height_col) / offset_channels;\n    // compute the start and end of the output\n\n    const int deformable_group_index = c / (2 * kernel_h * kernel_w);\n    const int col_step = kernel_h * kernel_w;\n    int cnt = 0;\n    const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;\n    const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;\n    const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + 
deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;\n    const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;\n\n    const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;\n\n    for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)\n    {\n      const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;\n      const int bp_dir = offset_c % 2;\n\n      int j = (col_pos / width_col / height_col / batch_size) % kernel_w;\n      int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;\n      int w_out = col_pos % width_col;\n      int h_out = (col_pos / width_col) % height_col;\n      int w_in = w_out * stride_w - pad_w;\n      int h_in = h_out * stride_h - pad_h;\n      const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);\n      const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);\n      const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);\n      const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];\n      const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];\n      const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];\n      scalar_t inv_h = h_in + i * dilation_h + offset_h;\n      scalar_t inv_w = w_in + j * dilation_w + offset_w;\n      if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)\n      {\n        inv_h = inv_w = -2;\n      }\n      else\n      {\n        mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);\n      }\n      const scalar_t weight = dmcn_get_coordinate_weight(\n          inv_h, inv_w,\n          height, width, data_im_ptr + cnt * height * width, width, bp_dir);\n      val += weight * data_col_ptr[col_pos] * mask;\n      cnt += 1;\n    }\n    // KERNEL_ASSIGN(grad_offset[index], offset_req, val);\n    grad_offset[index] = val;\n    if (offset_c % 2 == 0)\n      // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);\n      grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;\n  }\n}\n\nvoid modulated_deformable_im2col_cuda(\n    const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,\n    const int batch_size, const int channels, const int height_im, const int width_im,\n    const int height_col, const int width_col, const int kernel_h, const int kenerl_w,\n    const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n    const int dilation_h, const int dilation_w,\n    const int deformable_group, at::Tensor data_col)\n{\n  // num_axes should be smaller than block size\n  const int channel_per_deformable_group = channels / deformable_group;\n  const int num_kernels = channels * batch_size * height_col * width_col;\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n      data_im.scalar_type(), \"modulated_deformable_im2col_gpu\", ([&] {\n        const scalar_t *data_im_ = data_im.data<scalar_t>();\n        const scalar_t *data_offset_ = data_offset.data<scalar_t>();\n        const scalar_t *data_mask_ = data_mask.data<scalar_t>();\n        scalar_t *data_col_ = 
data_col.data<scalar_t>();\n\n        modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(\n            num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,\n            pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,\n            batch_size, channels, deformable_group, height_col, width_col, data_col_);\n      }));\n\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_im2col_cuda: %s\\n\", cudaGetErrorString(err));\n  }\n}\n\nvoid modulated_deformable_col2im_cuda(\n    const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,\n    const int batch_size, const int channels, const int height_im, const int width_im,\n    const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n    const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n    const int dilation_h, const int dilation_w,\n    const int deformable_group, at::Tensor grad_im)\n{\n\n  const int channel_per_deformable_group = channels / deformable_group;\n  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n      data_col.scalar_type(), \"modulated_deformable_col2im_gpu\", ([&] {\n        const scalar_t *data_col_ = data_col.data<scalar_t>();\n        const scalar_t *data_offset_ = data_offset.data<scalar_t>();\n        const scalar_t *data_mask_ = data_mask.data<scalar_t>();\n        scalar_t *grad_im_ = grad_im.data<scalar_t>();\n\n        modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(\n            num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,\n            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,\n            dilation_h, dilation_w, channel_per_deformable_group,\n            batch_size, deformable_group, height_col, width_col, grad_im_);\n      }));\n\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_col2im_cuda: %s\\n\", cudaGetErrorString(err));\n  }\n}\n\nvoid modulated_deformable_col2im_coord_cuda(\n    const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,\n    const int batch_size, const int channels, const int height_im, const int width_im,\n    const int height_col, const int width_col, const int kernel_h, const int kernel_w,\n    const int pad_h, const int pad_w, const int stride_h, const int stride_w,\n    const int dilation_h, const int dilation_w,\n    const int deformable_group,\n    at::Tensor grad_offset, at::Tensor grad_mask)\n{\n  const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;\n  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;\n\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(\n      data_col.scalar_type(), \"modulated_deformable_col2im_coord_gpu\", ([&] {\n        const scalar_t *data_col_ = data_col.data<scalar_t>();\n        const scalar_t *data_im_ = data_im.data<scalar_t>();\n        const scalar_t *data_offset_ = data_offset.data<scalar_t>();\n        const scalar_t *data_mask_ = data_mask.data<scalar_t>();\n        scalar_t *grad_offset_ = grad_offset.data<scalar_t>();\n        scalar_t *grad_mask_ = grad_mask.data<scalar_t>();\n\n        
modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(\n            num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,\n            kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,\n            dilation_h, dilation_w, channel_per_deformable_group,\n            batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,\n            grad_offset_, grad_mask_);\n      }));\n  cudaError_t err = cudaGetLastError();\n  if (err != cudaSuccess)\n  {\n    printf(\"error in modulated_deformable_col2im_coord_cuda: %s\\n\", cudaGetErrorString(err));\n  }\n}\n"
  },
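  {
    "path": "examples/dcn_im2col_shapes.py",
    "content": "# NOT part of the original repository: an illustrative sketch of the\n# data_col geometry produced by deformable_im2col_gpu_kernel in\n# src/model/dcn/src/deform_conv_cuda_kernel.cu. The file name and the\n# helper names below are hypothetical; only the formulas come from the\n# CUDA code itself.\n\n\ndef conv_output_extent(size, pad, dilation, ksize, stride):\n    # Mirrors the host-side computation in deformable_im2col, e.g.\n    # (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1\n    return (size + 2 * pad - (dilation * (ksize - 1) + 1)) // stride + 1\n\n\ndef data_col_shape(channels, batch, height, width, ksize_h, ksize_w,\n                   pad=1, stride=1, dilation=1):\n    # The kernel indexes data_col as if it were shaped\n    # [channels * ksize_h * ksize_w, batch, height_col, width_col]:\n    # each (i, j) kernel tap advances data_col_ptr by\n    # batch_size * height_col * width_col.\n    h_col = conv_output_extent(height, pad, dilation, ksize_h, stride)\n    w_col = conv_output_extent(width, pad, dilation, ksize_w, stride)\n    return (channels * ksize_h * ksize_w, batch, h_col, w_col)\n\n\nif __name__ == '__main__':\n    # A 3x3 deformable conv over a 2x64x48x48 batch keeps the 48x48\n    # spatial extent when pad=1, stride=1, dilation=1.\n    print(data_col_shape(64, 2, 48, 48, 3, 3))  # (576, 2, 48, 48)\n"
  },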
  {
    "path": "src/model/ddbpn.py",
    "content": "# Deep Back-Projection Networks For Super-Resolution\n# https://arxiv.org/abs/1803.02735\n\nfrom model import common\n\nimport torch\nimport torch.nn as nn\n\n\ndef make_model(args, parent=False):\n    return DDBPN(args)\n\ndef projection_conv(in_channels, out_channels, scale, up=True):\n    kernel_size, stride, padding = {\n        2: (6, 2, 2),\n        4: (8, 4, 2),\n        8: (12, 8, 2)\n    }[scale]\n    if up:\n        conv_f = nn.ConvTranspose2d\n    else:\n        conv_f = nn.Conv2d\n\n    return conv_f(\n        in_channels, out_channels, kernel_size,\n        stride=stride, padding=padding\n    )\n\nclass DenseProjection(nn.Module):\n    def __init__(self, in_channels, nr, scale, up=True, bottleneck=True):\n        super(DenseProjection, self).__init__()\n        if bottleneck:\n            self.bottleneck = nn.Sequential(*[\n                nn.Conv2d(in_channels, nr, 1),\n                nn.PReLU(nr)\n            ])\n            inter_channels = nr\n        else:\n            self.bottleneck = None\n            inter_channels = in_channels\n\n        self.conv_1 = nn.Sequential(*[\n            projection_conv(inter_channels, nr, scale, up),\n            nn.PReLU(nr)\n        ])\n        self.conv_2 = nn.Sequential(*[\n            projection_conv(nr, inter_channels, scale, not up),\n            nn.PReLU(inter_channels)\n        ])\n        self.conv_3 = nn.Sequential(*[\n            projection_conv(inter_channels, nr, scale, up),\n            nn.PReLU(nr)\n        ])\n\n    def forward(self, x):\n        if self.bottleneck is not None:\n            x = self.bottleneck(x)\n\n        a_0 = self.conv_1(x)\n        b_0 = self.conv_2(a_0)\n        e = b_0.sub(x)\n        a_1 = self.conv_3(e)\n\n        out = a_0.add(a_1)\n\n        return out\n\nclass DDBPN(nn.Module):\n    def __init__(self, args):\n        super(DDBPN, self).__init__()\n        scale = args.scale[0]\n\n        n0 = 128\n        nr = 32\n        self.depth = 6\n\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        initial = [\n            nn.Conv2d(args.n_colors, n0, 3, padding=1),\n            nn.PReLU(n0),\n            nn.Conv2d(n0, nr, 1),\n            nn.PReLU(nr)\n        ]\n        self.initial = nn.Sequential(*initial)\n\n        self.upmodules = nn.ModuleList()\n        self.downmodules = nn.ModuleList()\n        channels = nr\n        for i in range(self.depth):\n            self.upmodules.append(\n                DenseProjection(channels, nr, scale, True, i > 1)\n            )\n            if i != 0:\n                channels += nr\n        \n        channels = nr\n        for i in range(self.depth - 1):\n            self.downmodules.append(\n                DenseProjection(channels, nr, scale, False, i != 0)\n            )\n            channels += nr\n\n        reconstruction = [\n            nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1) \n        ]\n        self.reconstruction = nn.Sequential(*reconstruction)\n\n        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        x = self.initial(x)\n\n        h_list = []\n        l_list = []\n        for i in range(self.depth - 1):\n            if i == 0:\n                l = x\n            else:\n                l = torch.cat(l_list, dim=1)\n            h_list.append(self.upmodules[i](l))\n            
l_list.append(self.downmodules[i](torch.cat(h_list, dim=1)))\n        \n        h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1)))\n        out = self.reconstruction(torch.cat(h_list, dim=1))\n        out = self.add_mean(out)\n\n        return out\n\n"
  },
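  {
    "path": "examples/ddbpn_projection_check.py",
    "content": "# NOT part of the original repository: a hypothetical sanity-check for\n# projection_conv in src/model/ddbpn.py, showing that its\n# (kernel, stride, padding) table up- and down-samples by exactly the\n# requested scale. Assumes src/ is on sys.path so `model` resolves.\n\nimport torch\n\nfrom model.ddbpn import projection_conv\n\nx = torch.rand(1, 32, 24, 24)\nfor scale in (2, 4, 8):\n    up = projection_conv(32, 32, scale, up=True)     # nn.ConvTranspose2d\n    down = projection_conv(32, 32, scale, up=False)  # nn.Conv2d\n    assert up(x).shape[-1] == 24 * scale\n    assert down(up(x)).shape[-1] == 24\nprint('projection_conv table is self-consistent for x2/x4/x8')\n"
  },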
  {
    "path": "src/model/edsr.py",
    "content": "from model import common\n\nimport torch.nn as nn\n\nurl = {\n    'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',\n    'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',\n    'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',\n    'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',\n    'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',\n    'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt'\n}\n\ndef make_model(args, parent=False):\n    return EDSR(args)\n\nclass EDSR(nn.Module):\n    def __init__(self, args, conv=common.default_conv):\n        super(EDSR, self).__init__()\n\n        n_resblocks = args.n_resblocks\n        n_feats = args.n_feats\n        kernel_size = 3 \n        scale = args.scale[0]\n        act = nn.ReLU(True)\n        url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale)\n        if url_name in url:\n            self.url = url[url_name]\n        else:\n            self.url = None\n        self.sub_mean = common.MeanShift(args.rgb_range)\n        self.add_mean = common.MeanShift(args.rgb_range, sign=1)\n\n        # define head module\n        m_head = [conv(args.n_colors, n_feats, kernel_size)]\n\n        # define body module\n        m_body = [\n            common.ResBlock(\n                conv, n_feats, kernel_size, act=act, res_scale=args.res_scale\n            ) for _ in range(n_resblocks)\n        ]\n        m_body.append(conv(n_feats, n_feats, kernel_size))\n\n        # define tail module\n        m_tail = [\n            common.Upsampler(conv, scale, n_feats, act=False),\n            conv(n_feats, args.n_colors, kernel_size)\n        ]\n\n        self.head = nn.Sequential(*m_head)\n        self.body = nn.Sequential(*m_body)\n        self.tail = nn.Sequential(*m_tail)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        x = self.head(x)\n\n        res = self.body(x)\n        res += x\n\n        x = self.tail(res)\n        x = self.add_mean(x)\n\n        return x \n\n    def load_state_dict(self, state_dict, strict=True):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') == -1:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n"
  },
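  {
    "path": "examples/edsr_usage.py",
    "content": "# NOT part of the original repository: a minimal usage sketch for\n# src/model/edsr.py. Assumes it runs with src/ on sys.path so that\n# `model.common` resolves; the field values mirror the r16f64x2\n# baseline named in the `url` table of edsr.py.\n\nfrom types import SimpleNamespace\n\nimport torch\n\nfrom model.edsr import make_model\n\nargs = SimpleNamespace(\n    n_resblocks=16,  # baseline depth ('r16f64x2')\n    n_feats=64,      # baseline width\n    scale=[2],       # EDSR reads args.scale[0]\n    res_scale=1,     # residual scaling inside each ResBlock\n    rgb_range=255,   # used by common.MeanShift\n    n_colors=3,      # RGB input and output\n)\n\nmodel = make_model(args)\nwith torch.no_grad():\n    sr = model(torch.rand(1, 3, 48, 48) * args.rgb_range)\nprint(sr.shape)  # torch.Size([1, 3, 96, 96]) for x2 upscaling\n"
  },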
  {
    "path": "src/model/han.py",
    "content": "from model import common\nimport torch\nimport torch.nn as nn\nimport pdb\n\ndef make_model(args, parent=False):\n    return HAN(args)\n\n## Channel Attention (CA) Layer\nclass CALayer(nn.Module):\n    def __init__(self, channel, reduction=16):\n        super(CALayer, self).__init__()\n        # global average pooling: feature --> point\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        # feature channel downscale and upscale --> channel weight\n        self.conv_du = nn.Sequential(\n                nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),\n                nn.ReLU(inplace=True),\n                nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),\n                nn.Sigmoid()\n        )\n\n    def forward(self, x):\n        y = self.avg_pool(x)\n        y = self.conv_du(y)\n        return x * y\n\nclass LAM_Module(nn.Module):\n    \"\"\" Layer attention module\"\"\"\n    def __init__(self, in_dim):\n        super(LAM_Module, self).__init__()\n        self.chanel_in = in_dim\n\n\n        self.gamma = nn.Parameter(torch.zeros(1))\n        self.softmax  = nn.Softmax(dim=-1)\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X N X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X N X N\n        \"\"\"\n        m_batchsize, N, C, height, width = x.size()\n        proj_query = x.view(m_batchsize, N, -1)\n        proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1)\n        energy = torch.bmm(proj_query, proj_key)\n        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        attention = self.softmax(energy_new)\n        proj_value = x.view(m_batchsize, N, -1)\n\n        out = torch.bmm(attention, proj_value)\n        out = out.view(m_batchsize, N, C, height, width)\n\n        out = self.gamma*out + x\n        out = out.view(m_batchsize, -1, height, width)\n        return out\n\nclass CSAM_Module(nn.Module):\n    \"\"\" Channel-Spatial attention module\"\"\"\n    def __init__(self, in_dim):\n        super(CSAM_Module, self).__init__()\n        self.chanel_in = in_dim\n\n\n        self.conv = nn.Conv3d(1, 1, 3, 1, 1)\n        self.gamma = nn.Parameter(torch.zeros(1))\n        #self.softmax  = nn.Softmax(dim=-1)\n        self.sigmoid = nn.Sigmoid()\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X N X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X N X N\n        \"\"\"\n        m_batchsize, C, height, width = x.size()\n        out = x.unsqueeze(1)\n        out = self.sigmoid(self.conv(out))\n        \n        # proj_query = x.view(m_batchsize, N, -1)\n        # proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1)\n        # energy = torch.bmm(proj_query, proj_key)\n        # energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        # attention = self.softmax(energy_new)\n        # proj_value = x.view(m_batchsize, N, -1)\n\n        # out = torch.bmm(attention, proj_value)\n        # out = out.view(m_batchsize, N, C, height, width)\n\n        out = self.gamma*out\n        out = out.view(m_batchsize, -1, height, width)\n        x = x * out + x\n        return x\n\n## Residual Channel Attention Block (RCAB)\nclass RCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction,\n        bias=True, 
bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(RCAB, self).__init__()\n        modules_body = []\n        for i in range(2):\n            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))\n            if bn: modules_body.append(nn.BatchNorm2d(n_feat))\n            if i == 0: modules_body.append(act)\n        modules_body.append(CALayer(n_feat, reduction))\n        self.body = nn.Sequential(*modules_body)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x)\n        #res = self.body(x).mul(self.res_scale)\n        res += x\n        return res\n\n## Residual Group (RG)\nclass ResidualGroup(nn.Module):\n    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):\n        super(ResidualGroup, self).__init__()\n        modules_body = []\n        modules_body = [\n            RCAB(\n                conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \\\n            for _ in range(n_resblocks)]\n        modules_body.append(conv(n_feat, n_feat, kernel_size))\n        self.body = nn.Sequential(*modules_body)\n\n    def forward(self, x):\n        res = self.body(x)\n        res += x\n        return res\n\n## Holistic Attention Network (HAN)\nclass HAN(nn.Module):\n    def __init__(self, args, conv=common.default_conv):\n        super(HAN, self).__init__()\n        \n        n_resgroups = args.n_resgroups\n        n_resblocks = args.n_resblocks\n        n_feats = args.n_feats\n        kernel_size = 3\n        reduction = args.reduction \n        scale = args.scale[0]\n        act = nn.ReLU(True)\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        \n        # define head module\n        modules_head = [conv(args.n_colors, n_feats, kernel_size)]\n\n        # define body module\n        modules_body = [\n            ResidualGroup(\n                conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \\\n            for _ in range(n_resgroups)]\n\n        modules_body.append(conv(n_feats, n_feats, kernel_size))\n\n        # define tail module\n        modules_tail = [\n            common.Upsampler(conv, scale, n_feats, act=False),\n            conv(n_feats, args.n_colors, kernel_size)]\n\n        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n        self.csa = CSAM_Module(n_feats)\n        self.la = LAM_Module(n_feats)\n        self.last_conv = nn.Conv2d(n_feats*11, n_feats, 3, 1, 1)\n        self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1)\n        self.tail = nn.Sequential(*modules_tail)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        x = self.head(x)\n        res = x\n        #pdb.set_trace()\n        for name, midlayer in self.body._modules.items():\n            res = midlayer(res)\n            #print(name)\n            if name=='0':\n                res1 = res.unsqueeze(1)\n            else:\n                res1 = torch.cat([res.unsqueeze(1),res1],1)\n        #res = self.body(x)\n        out1 = res\n        #res3 = res.unsqueeze(1)\n        #res = torch.cat([res1,res3],1)\n        res = self.la(res1)\n        out2 = self.last_conv(res)\n\n        out1 = self.csa(out1)\n        out = torch.cat([out1, out2], 1)\n        
res = self.last(out)\n        \n        res += x\n        #res = self.csa(res)\n\n        x = self.tail(res)\n        x = self.add_mean(x)\n\n        return x \n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replace pre-trained upsampler to new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))"
  },
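  {
    "path": "examples/han_lam_shapes.py",
    "content": "# NOT part of the original repository: a hypothetical shape walk-through\n# of LAM_Module.forward in src/model/han.py. The layer attends over the\n# N stacked residual-group outputs. Assumes src/ is on sys.path.\n\nimport torch\n\nfrom model.han import LAM_Module\n\nB, N, C, H, W = 2, 10, 64, 24, 24\nlam = LAM_Module(C)\n\nx = torch.rand(B, N, C, H, W)\nout = lam(x)\n\n# Internally: proj_query/proj_key are (B, N, C*H*W) flattenings, so the\n# energy matrix is (B, N, N); the attended stack is folded back into the\n# channel dimension for the convolutions that follow in HAN.forward.\nprint(out.shape)  # torch.Size([2, 640, 24, 24])\n"
  },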
  {
    "path": "src/model/matrixmodel.py",
    "content": "# ------------------------------------------------------------------------------\n# Copyright (c) Microsoft\n# Licensed under the MIT License.\n# Written by Bin Xiao (Bin.Xiao@microsoft.com)\n# ------------------------------------------------------------------------------\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport logging\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\nimport torch.nn.functional as F\nfrom model import ops\nimport pdb\n\n\ntry:\n    from model.dcn.deform_conv import ModulatedDeformConvPack as DCN\nexcept ImportError:\n    raise ImportError('Failed to import DCNv2 module.')\n\nBN_MOMENTUM = 0.1\nlogger = logging.getLogger(__name__)\n\ndef initialize_weights(net_l, scale=1):\n    if not isinstance(net_l, list):\n        net_l = [net_l]\n    for net in net_l:\n        for m in net.modules():\n            if isinstance(m, nn.Conv2d):\n                init.kaiming_normal_(m.weight, a=0, mode='fan_in')\n                m.weight.data *= scale  # for residual block\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, nn.Linear):\n                init.kaiming_normal_(m.weight, a=0, mode='fan_in')\n                m.weight.data *= scale\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, nn.BatchNorm2d):\n                init.constant_(m.weight, 1)\n                init.constant_(m.bias.data, 0.0)\n\nclass ResBlock(nn.Module):\n    def __init__(\n        self, num_channels, kernel_size=3,\n        bias=True, bn=False, act=nn.ReLU(True), res_scale=1,**kwargs):\n\n        super(ResBlock, self).__init__()\n        m = []\n        for i in range(2):\n            m.append(nn.Conv2d(num_channels, num_channels, kernel_size, stride=1, padding=1, bias=bias))\n            if bn: m.append(nn.BatchNorm2d(num_channels))\n            if i == 0: m.append(act)\n\n        self.body = nn.Sequential(*m)\n        self.res_scale = res_scale\n        initialize_weights([self.body], 0.1)\n\n    def forward(self, x):\n        res = self.body(x).mul(self.res_scale)\n        res += x\n\n        return res\n\nclass BFN(nn.Module):\n    def __init__(self, num_channels, kernel_size, reduction, n_blocks, block):\n        super(BFN, self).__init__()\n\n        branch1=[]\n        branch1.append(self._make_blocks(num_channels[0], num_channels[0], kernel_size, reduction, n_blocks, block))\n        branch1.append(nn.Conv2d(num_channels[0], num_channels[0], kernel_size, stride=1, padding=1, bias=True))\n        branch2=[]\n        branch2.append(self._make_blocks(num_channels[1], num_channels[1], kernel_size, reduction, n_blocks, block))\n        branch2.append(nn.Conv2d(num_channels[1], num_channels[1], kernel_size, stride=1, padding=1, bias=True))\n        branch3=[]\n        branch3.append(self._make_blocks(num_channels[2], num_channels[2], kernel_size, reduction, n_blocks, block))\n        branch3.append(nn.Conv2d(num_channels[2], num_channels[2], kernel_size, stride=1, padding=1, bias=True))\n        self.branch1 = nn.Sequential(*branch1)\n        self.branch2 = nn.Sequential(*branch2)\n        self.branch3 = nn.Sequential(*branch3)\n        #self.act=nn.ReLU(True)\n\n\n    def _make_blocks(self, in_channels, num_channels, kernel_size, reduction, n_blocks, block):\n        blocks = []\n        blocks = [block(in_channels=in_channels, num_channels=num_channels, 
reduction=reduction) \\\n            for _ in range(n_blocks)]\n        blocks.append(nn.Conv2d(num_channels, num_channels, kernel_size, stride=1, padding=1, bias=True))\n        \n        return nn.Sequential(*blocks)\n\n    def forward(self, x):\n        assert type(x) is tuple and len(x)==3\n        #branch1\n        res1 = x[0]\n        out1 = self.branch1(x[0])\n        out1 += res1\n\n        #branch2\n        res2 = x[1]\n        out2 = self.branch2(x[1])\n        out2 += res2\n\n        #branch3\n        res3 = x[2]\n        out3 = self.branch3(x[2])\n        out3 += res3\n\n        return (out1,out2,out3)\n\nclass BFN1(nn.Module):\n    def __init__(self, num_channels, kernel_size, reduction, n_blocks, block):\n        super(BFN1, self).__init__()\n\n        branch1=[]\n        branch1.append(self._make_blocks(num_channels, num_channels, kernel_size, reduction, n_blocks, block))\n        branch1.append(nn.Conv2d(num_channels, num_channels, kernel_size, stride=1, padding=1, bias=True))\n        self.branch1 = nn.Sequential(*branch1)\n        #self.act=nn.ReLU(True)\n\n\n    def _make_blocks(self, in_channels, num_channels, kernel_size, reduction, n_blocks, block):\n        blocks = []\n        blocks = [block(in_channels=in_channels, num_channels=num_channels, reduction=reduction) \\\n            for _ in range(n_blocks)]\n        blocks.append(nn.Conv2d(num_channels, num_channels, kernel_size, stride=1, padding=1, bias=True))\n        \n        return nn.Sequential(*blocks)\n\n    def forward(self, x):\n        #branch1\n        res1 = x\n        out1 = self.branch1(x)\n        out1 += res1\n\n        return out1\n\nclass BFN2(nn.Module):\n    def __init__(self, num_channels, kernel_size, reduction, n_blocks, block):\n        super(BFN2, self).__init__()\n\n        branch1=[]\n        branch1.append(self._make_blocks(num_channels[0], num_channels[0], kernel_size, reduction, n_blocks, block))\n        branch1.append(nn.Conv2d(num_channels[0], num_channels[0], kernel_size, stride=1, padding=1, bias=True))\n        branch2=[]\n        branch2.append(self._make_blocks(num_channels[1], num_channels[1], kernel_size, reduction, n_blocks, block))\n        branch2.append(nn.Conv2d(num_channels[1], num_channels[1], kernel_size, stride=1, padding=1, bias=True))\n        self.branch1 = nn.Sequential(*branch1)\n        self.branch2 = nn.Sequential(*branch2)\n        #self.act=nn.ReLU(True)\n\n\n    def _make_blocks(self, in_channels, num_channels, kernel_size, reduction, n_blocks, block):\n        blocks = []\n        blocks = [block(in_channels=in_channels, num_channels=num_channels, reduction=reduction) \\\n            for _ in range(n_blocks)]\n        blocks.append(nn.Conv2d(num_channels, num_channels, kernel_size, stride=1, padding=1, bias=True))\n        \n        return nn.Sequential(*blocks)\n\n    def forward(self, x):\n        assert type(x) is tuple and len(x)==2\n        #branch1\n        res1 = x[0]\n        out1 = self.branch1(x[0])\n        out1 += res1\n\n        #branch2\n        res2 = x[1]\n        out2 = self.branch2(x[1])\n        out2 += res2\n\n        return (out1,out2)\n\nclass EoctResBlock(nn.Module):\n    expansion = 1\n\n    def __init__(self, in_channels, num_channels, stride=1, downsample=None, res_scale=1, **kwargs):\n        super(EoctResBlock, self).__init__()\n        self.num_channels = num_channels # (64,64,64)\n        self.stride = stride\n        self.downsample = downsample\n        self.res_scale = res_scale\n        self.conv1 = 
ops.EoctConv(in_channels, num_channels, stride=stride)\n        self.conv2 = ops.EoctConv(num_channels, num_channels)\n\n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        #out = ops.bn(out, self.num_channels)\n        out = ops.relu(out)\n\n        out = self.conv2(out)\n        #out = ops.bn(out, self.num_channels)\n        \n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        #out = out * self.res_scale + residual\n        # features travel as tuples of per-octave tensors; ops.tupleSum adds them element-wise\n        out = ops.tupleSum(out,residual)\n        #pdb.set_trace()\n        out = ops.relu(out)\n\n        return out\n\nclass EoctBottleneck(nn.Module):\n    def __init__(self, in_channels, num_channels, stride=1, downsample=None, res_scale=1, **kwargs):\n        super(EoctBottleneck, self).__init__()\n        self.num_channels = num_channels\n        self.stride = stride\n        self.downsample = downsample\n        self.res_scale = res_scale\n        expand = 6\n        linear = 0.8\n        self.conv1 = ops.EoctConv(in_channels, ops.tupleMultiply(num_channels,expand), kernel_size=1, padding=1//2)\n        #self.bn1 = nn.BatchNorm2d(num_channels*expand, momentum=BN_MOMENTUM)\n        self.conv2 = ops.EoctConv(ops.tupleMultiply(num_channels,expand), int(ops.tupleMultiply(num_channels,linear)), kernel_size=1, padding=1//2)\n        self.conv3 = ops.EoctConv(int(ops.tupleMultiply(num_channels,linear)), num_channels, kernel_size=3, padding=3//2)\n    \n    def forward(self, x):\n        residual = x\n\n        out = self.conv1(x)\n        #out = ops.bn(out, self.num_channels)\n        out = ops.relu(out)\n\n        out = self.conv2(out)\n        #out = ops.bn(out, self.num_channels)\n        \n        out = self.conv3(out)\n        #out = ops.bn(out, self.num_channels)\n        \n        if self.downsample is not None:\n            residual = self.downsample(x)\n\n        #out = out * self.res_scale + residual\n        out = ops.tupleSum(out,residual)\n        out = ops.relu(out)\n\n        return out\n        \n\nclass CALayer(nn.Module):\n    def __init__(self, in_channels, num_channels, reduction=16):\n        super(CALayer, self).__init__()\n        \n        # feature channel downscale and upscale --> channel weight\n        self.conv1 = ops.EoctConv(in_channels, num_channels // reduction, 1, padding=0, bias=True)\n        self.conv2 = ops.EoctConv(num_channels // reduction, num_channels, 1, padding=0, bias=True)\n\n\n    def forward(self, x):\n    \n        out = ops.avg_pool2d(x)\n        \n        out = self.conv1(out)\n        out = ops.relu(out)\n        out = self.conv2(out)\n        out = ops.sigmoid(out)\n        \n        return x * out\n\nclass CAEoctResBlock(nn.Module):\n    def __init__(self, in_channels, num_channels, reduction, stride=1, bias=True, res_scale=1, **kwargs):\n        super(CAEoctResBlock, self).__init__()\n        self.num_channels = num_channels # [64,64,64,64]\n        self.res_scale = res_scale\n        self.conv1 = ops.EoctConv(in_channels, num_channels, stride=stride)\n        self.conv2 = ops.EoctConv(num_channels, num_channels)\n        self.caLayer = CALayer(num_channels, num_channels, reduction)\n        \n    def forward(self, x):\n        res = x\n        \n        out = self.conv1(x)\n        out = ops.relu(out)\n        out = self.conv2(out)\n        \n        out = self.caLayer(out)\n        \n        out = ops.tupleSum(out,res)\n        #out = out * self.res_scale + res\n        out = ops.relu(out)\n        \n        \n        return out\n\n
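# Usage sketch (illustrative only; assumes the DCNv2 extension imported above\n# is built). make_model() below constructs MatrixModelG2 from an argparse-style\n# namespace; the attribute names match what this file actually reads (reduction,\n# scale, rgb_range, block, plus n_resgroups/n_resblocks for MatrixModel), while\n# the concrete values here are placeholders:\n#\n#     from types import SimpleNamespace\n#     args = SimpleNamespace(n_resgroups=10, n_resblocks=20, reduction=16,\n#                            scale=4, rgb_range=255, block='EctBASIC')\n#     model = make_model(args)\n#     sr = model(torch.randn(1, 3, 48, 48))\n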
blocks_dict = {\n    'BASIC':ResBlock,\n    'EctBASIC': EoctResBlock,\n    'EctBOTTLENECK': EoctBottleneck,\n    'CAEctBASIC':CAEoctResBlock\n}\n\ndef make_model(args, parent=False):\n    return MatrixModelG2(args)\n\nclass MatrixModel(nn.Module):\n    def __init__(self, args):\n        super(MatrixModel, self).__init__()\n        \n        n_groups = args.n_resgroups\n        n_blocks = args.n_resblocks\n        num_channels = (64, 64, 64)\n        kernel_size = 3\n        reduction = args.reduction\n        scale = args.scale\n        block = EoctResBlock\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        self.add_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n        \n        self.first_conv = ops.EoctConv(3, 64)\n        \n        modules_body1 = []\n        modules_body1.append(self._make_blocks(64, 64, kernel_size, reduction, n_blocks, block))\n        modules_body1.append(ops.EoctConv(64, (64,64), kernel_size))\n        \n        modules_body2 = []\n        modules_body2.append(self._make_blocks((64,64), (64,64), kernel_size, reduction, n_blocks, block))\n        modules_body2.append(ops.EoctConv((64,64), num_channels, kernel_size))\n        \n        modules_body3 = []\n        modules_body3.append(self._make_blocks(num_channels, num_channels, kernel_size, reduction, n_blocks, block))\n        modules_body3.append(ops.EoctConv(num_channels, 64, kernel_size))\n        \n        modules_tail = [\n            ops._UpsampleBlock(num_channels[0], scale=scale),\n            nn.Conv2d(num_channels[0], 3, kernel_size, 1, 1)]\n        \n        self.body1 = nn.Sequential(*modules_body1)\n        self.body2 = nn.Sequential(*modules_body2)\n        self.body3 = nn.Sequential(*modules_body3)\n        self.tail = nn.Sequential(*modules_tail)\n        \n    def _make_blocks(self, in_channels, num_channels, kernel_size, reduction, n_blocks, block):\n        blocks = []\n        blocks = [block(in_channels=in_channels, num_channels=num_channels, reduction=reduction) \\\n            for _ in range(n_blocks)]\n        blocks.append(ops.EoctConv(num_channels, num_channels, kernel_size))\n        \n        return nn.Sequential(*blocks)\n        \n    def forward(self, x):\n        \n        x = self.sub_mean(x)\n        x = self.first_conv(x)\n\n        res = x\n        x = self.body1(x)\n        x = self.body2(x)\n        x = self.body3(x)\n        x += res\n        #pdb.set_trace()\n\n        out = self.tail(x)\n        out = self.add_mean(out)\n        \n        return out\n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replace pre-trained upsampler to new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n            
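        # a mismatched tail (the upsampler) is tolerated even in strict mode, e.g. when fine-tuning at a new scale; only non-tail keys are errors\n            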
        raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))\n\nclass RERB(nn.Module):\n    def __init__(self, in_channels, num_channels, kernel_size, reduction, n_blocks, block):\n        super(RERB, self).__init__()\n\n        blocks = []\n        blocks.append(self._make_blocks(in_channels, num_channels, kernel_size, reduction, n_blocks, block))\n        blocks.append(ops.EoctConv(num_channels, num_channels, kernel_size))\n        self.body = nn.Sequential(*blocks)\n\n    def _make_blocks(self, in_channels, num_channels, kernel_size, reduction, n_blocks, block):\n        blocks = []\n        blocks = [block(in_channels=in_channels, num_channels=num_channels, reduction=reduction) \\\n            for _ in range(n_blocks)]\n        blocks.append(ops.EoctConv(num_channels, num_channels, kernel_size))\n        \n        return nn.Sequential(*blocks)\n\n    def forward(self, x):\n        res = x\n        x = self.body(x)\n        x = ops.tupleSum(x,res)\n        x = ops.relu(x)\n\n        return x\n\n\nclass MatrixModelB(nn.Module):\n    def __init__(self, args):\n        super(MatrixModelB, self).__init__()\n        \n        num_channels = (64, 64, 64)\n        kernel_size = 3\n        reduction = args.reduction \n        scale = args.scale\n        block = blocks_dict[args.block]\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        self.add_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n        \n        self.first_conv = nn.Conv2d(3, 64, kernel_size, stride=1, padding=1, bias=True)\n        \n        modules_stage1 = []\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        self.stage1 = nn.Sequential(*modules_stage1)\n        self.stage1_conv = ops.EoctConv(64, (64,64), kernel_size)\n\n        modules_stage2 = []\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        self.stage2 = nn.Sequential(*modules_stage2)\n        self.stage2_conv = ops.EoctConv((64,64), num_channels, kernel_size)\n\n        modules_stage3 = []\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage3 = nn.Sequential(*modules_stage3)\n        self.stage3_conv = ops.EoctConv(num_channels, num_channels, kernel_size)\n        \n        '''\n        modules_stage4 = []\n        modules_stage4.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage4 = nn.Sequential(*modules_stage4)\n        self.stage4_conv = ops.EoctConv(num_channels, num_channels, kernel_size)\n\n        \n        modules_body = []\n        for i in range(n_groups):\n            modules_body.append(RERB(num_channels, num_channels, kernel_size, reduction, n_blocks, block))\n        modules_body.append(ops.EoctConv(num_channels, num_channels, kernel_size))\n        '''\n        self.fusion_conv1 = ops.EoctConv(num_channels, num_channels, kernel_size)\n        
self.fusion_conv2 = ops.EoctConv(num_channels, num_channels, kernel_size)\n        self.fusion_conv3 = ops.EoctConv(num_channels, num_channels, kernel_size)\n        self.conv_last = ops.EoctConv(num_channels, 64, kernel_size)\n        \n        modules_tail1 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail1 = nn.Sequential(*modules_tail1)\n        '''\n        modules_tail2 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail2 = nn.Sequential(*modules_tail2)\n        \n        modules_tail3 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail3 = nn.Sequential(*modules_tail3)\n        '''\n              \n    def forward(self, x):\n        \n        x = self.sub_mean(x)\n        x = self.first_conv(x)\n        residual = x\n        #pdb.set_trace()\n\n        #stage1\n        x = self.stage1(x)\n        x = self.stage1_conv(x)\n        #pdb.set_trace()\n        L1_fea = x[0]\n\n        #stage2\n        x = self.stage2(x)\n        x = self.stage2_conv(x)\n        L2_fea = x[1]\n\n        #stage3\n        x = self.stage3(x)\n        x = self.stage3_conv(x)\n        L3_fea = x[2]\n        \n        #stage4\n        #x = self.stage4(x)\n        #x = self.stage4_conv(x)\n\n        x = (L1_fea, L2_fea, L3_fea)\n        res1 = x\n        x = self.fusion_conv1(x)\n        x = ops.tupleSum(x,res1)\n        res2 = x\n        x = self.fusion_conv2(x)\n        x = ops.tupleSum(x,res2)\n        res3 = x\n        x = self.fusion_conv3(x)\n        x = ops.tupleSum(x,res3)\n        out = self.conv_last(x)\n        out += residual\n\n        out = self.tail1(out)\n        out = self.add_mean(out)\n\n        #out2 = self.tail1(x[1])\n        #out2 = self.add_mean(out2)\n\n        #out3 = self.tail2(x[2])\n        #out3 = self.add_mean(out3)\n        #pdb.set_trace()\n        \n        return out\n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replace pre-trained upsampler to new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))\n\n
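# Input convention for the PD / PDF modules below (illustrative sketch; the\n# shapes are assumptions and the DCNv2 extension must be built): a 3-level\n# feature pyramid packed as a tuple, each level at half the resolution of the\n# previous one, e.g. with nf=64:\n#\n#     L1 = torch.randn(1, 64, 64, 64)\n#     L2 = torch.randn(1, 64, 32, 32)\n#     L3 = torch.randn(1, 64, 16, 16)\n#     aligned = PD()((L1, L2, L3))   # tuple of 3 aligned levels, top-down\n#     fused = PDF()(aligned)         # single [1, 64, 64, 64] fused map\n#\nclass PDF(nn.Module):\n    ''' Alignment 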
module using Pyramid, Deformable convolution and Fusion,\n    with 3 pyramid levels, bottom-up.\n    '''\n\n    def __init__(self, nf=64, groups=8):\n        super(PDF, self).__init__()\n        # L1: level 1, original spatial size\n        #self.L1_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.L1_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,\n                              extra_offset_mask=True)\n        # L2: level 2, 1/2 spatial size\n        self.L2_offset_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)  # concat for diff\n        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset\n        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)\n        self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,\n                              extra_offset_mask=True)\n        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea\n        # L3: level 3, 1/4 spatial size\n        self.L3_offset_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)  # concat for diff\n        self.L3_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset\n        self.L3_offset_conv3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)\n        self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,\n                              extra_offset_mask=True)\n        self.L3_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea\n\n        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=False)\n        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n        self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)\n        self.conv_last = nn.Conv2d(nf * 3, nf, 3, 1, 1, bias=True)\n\n    def forward(self, nbr_fea_l):\n        '''Align the pyramid levels with deformable convolution, fine-to-coarse, then fuse them.\n        nbr_fea_l: [L1, L2, L3] feature pyramid, each level with [B,C,H,W] features\n        '''\n        # L1\n        L1_offset = nbr_fea_l[0]\n        #L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))\n        L1_offset = self.lrelu(self.L1_offset_conv2(L1_offset))\n        L1_fea = self.lrelu(self.L1_dcnpack([nbr_fea_l[0], L1_offset]))\n        L1_f = L1_fea\n        # L2\n        L2_offset = nbr_fea_l[1]\n        L1_offset = self.lrelu(self.L2_offset_conv1(L1_offset))  # stride-2: bring the L1 offset down to L2 size\n        #L1_offset = F.interpolate(L1_offset, scale_factor=1/2, mode='bilinear', align_corners=False)\n        L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L1_offset * 2], dim=1)))\n        #L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))\n        L2_fea = self.L2_dcnpack([nbr_fea_l[1], L2_offset])\n        L1_fea = self.lrelu(self.L2_offset_conv3(L1_fea))\n        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L1_fea], dim=1)))\n        L2_f = L2_fea\n        # L3\n        L3_offset = nbr_fea_l[2]\n        #L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))\n        L2_offset = self.L3_offset_conv1(L2_offset)\n        L3_offset = self.lrelu(self.L3_offset_conv2(torch.cat([L3_offset, L2_offset * 2], dim=1)))\n        #L3_offset = self.lrelu(self.L3_offset_conv3(L3_offset))\n        L3_fea = self.L3_dcnpack([nbr_fea_l[2], L3_offset])\n        L2_fea = self.lrelu(self.L3_offset_conv3(L2_fea))\n        L3_fea = 
self.L3_fea_conv(torch.cat([L3_fea, L2_fea], dim=1))\n        # Fusion\n        L3_fea = self.upsample2(L3_fea)\n        L2_f = self.upsample(L2_f)\n        L_fea = torch.cat([torch.cat([L1_f, L2_f], dim=1),L3_fea],dim=1)\n        L_fea = self.lrelu(self.conv_last(L_fea))\n        return L_fea\n\nclass PD(nn.Module):\n    ''' Alignment module using Pyramid and Deformable convolution,\n    with 3 pyramid levels, top-down.\n    '''\n\n    def __init__(self, nf=64, groups=8):\n        super(PD, self).__init__()\n        # L3: level 3, 1/4 spatial size\n        #self.L3_offset_conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,\n                              extra_offset_mask=True)\n        # L2: level 2, 1/2 spatial size\n        #self.L2_offset_conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset\n        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,\n                              extra_offset_mask=True)\n        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea\n        # L1: level 1, original spatial size\n        #self.L1_offset_conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)  # concat for diff\n        self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset\n        self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n        self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,\n                              extra_offset_mask=True)\n        self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea\n        # Cascading DCN\n        #self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff\n        #self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)\n\n        #self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,extra_offset_mask=True)\n\n        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=False)\n        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n        #self.upsample2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)\n        #self.conv_last = nn.Conv2d(nf * 3, nf, 3, 1, 1, bias=True)\n\n    def forward(self, nbr_fea_l):\n        '''Align the pyramid levels with deformable convolution, coarse-to-fine.\n        nbr_fea_l: [L1, L2, L3] feature pyramid, each level with [B,C,H,W] features\n        '''\n        # L3\n        L3_offset = nbr_fea_l[2]\n        #L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))\n        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))\n        L3_fea = self.lrelu(self.L3_dcnpack([nbr_fea_l[2], L3_offset]))\n        L3_f = L3_fea\n        # L2\n        L2_offset = nbr_fea_l[1]\n        #L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))\n        L3_offset = self.upsample(L3_offset)\n        L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset * 2], dim=1)))\n        L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))\n        L2_fea = self.L2_dcnpack([nbr_fea_l[1], L2_offset])\n        L3_fea = self.upsample(L3_fea)\n        #pdb.set_trace()\n        L2_fea = 
self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))\n        L2_f = L2_fea\n        # L1\n        L1_offset = nbr_fea_l[0]\n        #L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))\n        L2_offset = self.upsample(L2_offset)\n        L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))\n        L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))\n        L1_fea = self.L1_dcnpack([nbr_fea_l[0], L1_offset])\n        L2_fea = self.upsample(L2_fea)\n        L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))\n        # Cascading\n        #offset = L1_fea\n        #offset = self.lrelu(self.cas_offset_conv1(offset))\n        #offset = self.lrelu(self.cas_offset_conv2(offset))\n        #L1_fea = self.lrelu(self.cas_dcnpack([L1_fea, offset]))\n        \n        #L3_f = self.upsample2(L3_f)\n        #L2_f = self.upsample(L2_f)\n        #L_fea = torch.cat([torch.cat([L1_fea, L2_f], dim=1),L3_f],dim=1)\n        #L_fea = self.lrelu(self.conv_last(L_fea))\n\n        return (L1_fea, L2_f, L3_f)\n\nclass MatrixModelC(nn.Module):\n    def __init__(self, args):\n        super(MatrixModelC, self).__init__()\n        \n        num_channels = (64, 64, 64)\n        kernel_size = 3\n        reduction = args.reduction \n        scale = args.scale\n        block = blocks_dict[args.block]\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        self.add_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n        \n        self.first_conv = nn.Conv2d(3, 64, kernel_size, stride=1, padding=1, bias=True)\n        \n        modules_stage1 = []\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        self.stage1 = nn.Sequential(*modules_stage1)\n        self.stage1_conv = ops.EoctConv(64, (64,64), kernel_size)\n\n        modules_stage2 = []\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        self.stage2 = nn.Sequential(*modules_stage2)\n        self.stage2_conv = ops.EoctConv((64,64), num_channels, kernel_size)\n\n        modules_stage3 = []\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage3 = nn.Sequential(*modules_stage3)\n        self.stage3_conv = ops.EoctConv(num_channels, num_channels, kernel_size)\n        \n        '''\n        modules_stage4 = []\n        modules_stage4.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage4 = nn.Sequential(*modules_stage4)\n        self.stage4_conv = ops.EoctConv(num_channels, num_channels, kernel_size)\n        '''\n        \n        self.pd = PD()\n        self.pdf = PDF()\n        modules_tail1 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        self.tail1 = nn.Sequential(*modules_tail1)\n        \n    def forward(self, x):\n        \n        x = self.sub_mean(x)\n        x = self.first_conv(x)\n        residual = x\n        #pdb.set_trace()\n\n        #stage1\n        x = self.stage1(x)\n        x = self.stage1_conv(x)\n        #pdb.set_trace()\n        L1_fea = x[0]\n\n        #stage2\n        x = 
self.stage2(x)\n        x = self.stage2_conv(x)\n        L2_fea = x[1]\n\n        #stage3\n        x = self.stage3(x)\n        x = self.stage3_conv(x)\n        L3_fea = x[2]\n        \n        #stage4\n        #x = self.stage4(x)\n        #x = self.stage4_conv(x)\n\n        #pdf-Two_way\n        x = (L1_fea, L2_fea, L3_fea)\n        x = self.pd(x)\n        out = self.pdf(x)\n        #long skip\n        out += residual\n\n        out = self.tail1(out)\n        out = self.add_mean(out)\n        \n        return out\n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replace pre-trained upsampler to new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))\n\n\nclass MatrixModelD(nn.Module):\n    def __init__(self, args):\n        super(MatrixModelD, self).__init__()\n        \n        num_channels = (64, 64, 64)\n        kernel_size = 3\n        reduction = args.reduction \n        scale = args.scale\n        block = blocks_dict[args.block]\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        self.add_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n        \n        self.first_conv = nn.Conv2d(3, 64, kernel_size, stride=1, padding=1, bias=True)\n        \n        modules_stage1 = []\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        self.stage1 = nn.Sequential(*modules_stage1)\n        self.stage1_conv = ops.EoctConv(64, (64,64), kernel_size)\n\n        modules_stage2 = []\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        self.stage2 = nn.Sequential(*modules_stage2)\n        self.stage2_conv = ops.EoctConv((64,64), num_channels, kernel_size)\n\n        modules_stage3 = []\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage3 = nn.Sequential(*modules_stage3)\n        self.stage3_conv = ops.EoctConv(num_channels, num_channels, kernel_size)\n        \n        '''\n        modules_stage4 = []\n        modules_stage4.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage4 = nn.Sequential(*modules_stage4)\n        self.stage4_conv = ops.EoctConv(num_channels, num_channels, kernel_size)\n\n        \n        modules_body = []\n        for i in 
range(n_groups):\n            modules_body.append(RERB(num_channels, num_channels, kernel_size, reduction, n_blocks, block))\n        modules_body.append(ops.EoctConv(num_channels, num_channels, kernel_size))\n        '''\n        \n        modules_tail1 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail1 = nn.Sequential(*modules_tail1)\n        '''\n        modules_tail2 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail2 = nn.Sequential(*modules_tail2)\n        \n        modules_tail3 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail3 = nn.Sequential(*modules_tail3)\n        '''\n              \n    def forward(self, x):\n        \n        x = self.sub_mean(x)\n        x = self.first_conv(x)\n        residual = x\n        #pdb.set_trace()\n\n        #stage1\n        x = self.stage1(x)\n        x = self.stage1_conv(x)\n        #pdb.set_trace()\n\n        #stage2\n        x = self.stage2(x)\n        x = self.stage2_conv(x)\n\n        #stage3\n        x = self.stage3(x)\n        x = self.stage3_conv(x)\n        out = x[0]\n        #out3 = x[2]\n        #pdb.set_trace()\n        \n        #stage4\n        #x = self.stage4(x)\n        #x = self.stage4_conv(x)\n\n        out += residual\n\n        out = self.tail1(out)\n        out = self.add_mean(out)\n\n        #out2 = self.tail1(x[1])\n        #out2 = self.add_mean(out2)\n\n        #out3 = self.tail3(out3)\n        #out3 = self.add_mean(out3)\n        #pdb.set_trace()\n        \n        return out\n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replace pre-trained upsampler to new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))\n\n\nclass MatrixModelE(nn.Module):\n    def __init__(self, args):\n        super(MatrixModelE, self).__init__()\n        \n        num_channels = (64, 64, 64)\n        kernel_size = 3\n        reduction = args.reduction \n        scale = args.scale\n        block = blocks_dict[args.block]\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n     
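   # std of 1.0 disables rescaling, so MeanShift only removes/restores the mean\n     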
   rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        self.add_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n        \n        self.first_conv = nn.Conv2d(3, 64, kernel_size, stride=1, padding=1, bias=True)\n        \n        modules_stage1 = []\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        self.stage1 = nn.Sequential(*modules_stage1)\n        self.stage1_conv = ops.EoctConv(64, (64,64), kernel_size)\n\n        modules_stage2 = []\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        self.stage2 = nn.Sequential(*modules_stage2)\n        self.stage2_conv = ops.EoctConv((64,64), num_channels, kernel_size)\n\n        modules_stage3 = []\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage3 = nn.Sequential(*modules_stage3)\n        self.stage3_conv = ops.EoctConv(num_channels, 64, kernel_size)\n        \n        '''\n        modules_stage4 = []\n        modules_stage4.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage4 = nn.Sequential(*modules_stage4)\n        self.stage4_conv = ops.EoctConv(num_channels, num_channels, kernel_size)\n\n        \n        modules_body = []\n        for i in range(n_groups):\n            modules_body.append(RERB(num_channels, num_channels, kernel_size, reduction, n_blocks, block))\n        modules_body.append(ops.EoctConv(num_channels, num_channels, kernel_size))\n        '''\n        \n        modules_tail1 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail1 = nn.Sequential(*modules_tail1)\n        '''\n        modules_tail2 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail2 = nn.Sequential(*modules_tail2)\n        \n        modules_tail3 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail3 = nn.Sequential(*modules_tail3)\n        '''\n              \n    def forward(self, x):\n        \n        x = self.sub_mean(x)\n        x = self.first_conv(x)\n        residual = x\n        #pdb.set_trace()\n\n        #stage1\n        x = self.stage1(x)\n        x = self.stage1_conv(x)\n        #pdb.set_trace()\n\n        #stage2\n        x = self.stage2(x)\n        x = self.stage2_conv(x)\n\n        #stage3\n        x = self.stage3(x)\n        out = self.stage3_conv(x)\n        \n        #stage4\n        #x = self.stage4(x)\n        #x = self.stage4_conv(x)\n\n        out += residual\n\n        out = self.tail1(out)\n        out = self.add_mean(out)\n\n        #out2 = self.tail1(x[1])\n        #out2 = self.add_mean(out2)\n\n        #out3 = self.tail2(x[2])\n        #out3 = self.add_mean(out3)\n        #pdb.set_trace()\n        \n        return out\n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n  
              if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replace pre-trained upsampler to new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))\n\n\nclass MatrixModelF(nn.Module):\n    def __init__(self, args):\n        super(MatrixModelF, self).__init__()\n        \n        num_channels = (64, 64, 64)\n        kernel_size = 3\n        reduction = args.reduction \n        scale = args.scale\n        block = blocks_dict[args.block]\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        self.add_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n        \n        self.first_conv = nn.Conv2d(3, 64, kernel_size, stride=1, padding=1, bias=True)\n        \n        modules_stage1 = []\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        self.stage1 = nn.Sequential(*modules_stage1)\n        self.stage1_conv = ops.EoctConv(64, (64,64), kernel_size)\n\n        modules_stage2 = []\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        self.stage2 = nn.Sequential(*modules_stage2)\n        self.stage2_conv = ops.EoctConv((64,64), num_channels, kernel_size)\n\n        modules_stage3 = []\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage3 = nn.Sequential(*modules_stage3)\n        self.stage3_conv = ops.EoctConv(num_channels, 64, kernel_size)\n        \n        '''\n        modules_stage4 = []\n        modules_stage4.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage4 = nn.Sequential(*modules_stage4)\n        self.stage4_conv = ops.EoctConv(num_channels, num_channels, kernel_size)\n\n        \n        modules_body = []\n        for i in range(n_groups):\n            modules_body.append(RERB(num_channels, num_channels, kernel_size, reduction, n_blocks, block))\n        modules_body.append(ops.EoctConv(num_channels, num_channels, kernel_size))\n        '''\n        \n        
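# tail: upsample by args.scale with ops._UpsampleBlock, then reconstruct 3 RGB channels\n        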
modules_tail1 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail1 = nn.Sequential(*modules_tail1)\n        '''\n        modules_tail2 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail2 = nn.Sequential(*modules_tail2)\n        \n        modules_tail3 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail3 = nn.Sequential(*modules_tail3)\n        '''\n              \n    def forward(self, x):\n        \n        x = self.sub_mean(x)\n        x = self.first_conv(x)\n        residual = x\n        #pdb.set_trace()\n\n        #stage1\n        x = self.stage1(x)\n        x = self.stage1_conv(x)\n        #pdb.set_trace()\n\n        #stage2\n        x = self.stage2(x)\n        x = self.stage2_conv(x)\n\n        #stage3\n        x = self.stage3(x)\n        out = self.stage3_conv(x)\n        \n        #stage4\n        #x = self.stage4(x)\n        #x = self.stage4_conv(x)\n\n        out += residual\n\n        out = self.tail1(out)\n        out = self.add_mean(out)\n\n        #out2 = self.tail1(x[1])\n        #out2 = self.add_mean(out2)\n\n        #out3 = self.tail2(x[2])\n        #out3 = self.add_mean(out3)\n        #pdb.set_trace()\n        \n        return out\n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replace pre-trained upsampler to new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))\n\n\nclass MatrixModelG(nn.Module):\n    def __init__(self, args):\n        super(MatrixModelG, self).__init__()\n        \n        num_channels = (64, 64, 64)\n        kernel_size = 3\n        reduction = args.reduction \n        scale = args.scale\n        block = blocks_dict[args.block]\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        self.add_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n        \n        self.first_conv = nn.Conv2d(3, 64, kernel_size, stride=1, padding=1, bias=True)\n        \n        
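# three-stage matrix of octave branches: stage1 runs a single 64-channel\n        # branch and stage1_conv splits it into two; stage2 runs two branches and\n        # splits into three; stage3 runs all three, then stage3_conv merges them\n        # back into one 64-channel map for the shared tail\n        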
modules_stage1 = []\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        self.stage1 = nn.Sequential(*modules_stage1)\n        self.stage1_conv = ops.EoctConv(64, (64,64), kernel_size)\n\n        modules_stage2 = []\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        self.stage2 = nn.Sequential(*modules_stage2)\n        self.stage2_conv = ops.EoctConv((64,64), num_channels, kernel_size)\n\n        modules_stage3 = []\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage3 = nn.Sequential(*modules_stage3)\n        self.stage3_conv = ops.EoctConv(num_channels, 64, kernel_size)\n        \n        '''\n        modules_stage4 = []\n        modules_stage4.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage4 = nn.Sequential(*modules_stage4)\n        self.stage4_conv = ops.EoctConv(num_channels, num_channels, kernel_size)\n\n        \n        modules_body = []\n        for i in range(n_groups):\n            modules_body.append(RERB(num_channels, num_channels, kernel_size, reduction, n_blocks, block))\n        modules_body.append(ops.EoctConv(num_channels, num_channels, kernel_size))\n        '''\n        self.last_conv = nn.Conv2d(64*3, 64, kernel_size, 1, 1)\n        modules_tail1 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail1 = nn.Sequential(*modules_tail1)\n        '''\n        modules_tail2 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail2 = nn.Sequential(*modules_tail2)\n        \n        modules_tail3 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail3 = nn.Sequential(*modules_tail3)\n        '''\n              \n    def forward(self, x):\n        \n        x = self.sub_mean(x)\n        x = self.first_conv(x)\n        residual = x\n        #pdb.set_trace()\n\n        #stage1\n        x = self.stage1(x)\n        x = self.stage1_conv(x)\n        out1 = x[0]\n        #pdb.set_trace()\n\n        #stage2\n        x = self.stage2(x)\n        x = self.stage2_conv(x)\n        out2 = x[0]\n\n        #stage3\n        x = self.stage3(x)\n        out = self.stage3_conv(x)\n        out2 = torch.cat([out1,out2], dim=1)\n        out = torch.cat([out2,out], dim=1)\n        out = self.last_conv(out)\n\n        \n        #stage4\n        #x = self.stage4(x)\n        #x = self.stage4_conv(x)\n\n        out += residual\n\n        out = self.tail1(out)\n        out = self.add_mean(out)\n\n        #out2 = self.tail1(x[1])\n        #out2 = self.add_mean(out2)\n\n        #out3 = self.tail2(x[2])\n        #out3 = self.add_mean(out3)\n        #pdb.set_trace()\n        \n        return out\n\n    def 
load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replace pre-trained upsampler to new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))\n\nclass MatrixModelG2(nn.Module):\n    def __init__(self, args):\n        super(MatrixModelG2, self).__init__()\n        \n        num_channels = (64, 64, 64)\n        kernel_size = 3\n        reduction = args.reduction \n        scale = args.scale\n        block = blocks_dict[args.block]\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        self.add_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n        \n        self.first_conv = nn.Conv2d(3, 64, kernel_size, stride=1, padding=1, bias=True)\n        \n        modules_stage1 = []\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        self.stage1 = nn.Sequential(*modules_stage1)\n        self.stage1_conv = ops.EoctConv(64, (64,64), kernel_size)\n\n        modules_stage2 = []\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        self.stage2 = nn.Sequential(*modules_stage2)\n        self.stage2_conv = ops.EoctConv((64,64), num_channels, kernel_size)\n\n        modules_stage3 = []\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage3 = nn.Sequential(*modules_stage3)\n        self.stage3_conv = ops.EoctConv(num_channels, 64, kernel_size)\n        \n        '''\n        modules_stage4 = []\n        modules_stage4.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage4 = 
nn.Sequential(*modules_stage4)\n        self.stage4_conv = ops.EoctConv(num_channels, num_channels, kernel_size)\n\n        \n        modules_body = []\n        for i in range(n_groups):\n            modules_body.append(RERB(num_channels, num_channels, kernel_size, reduction, n_blocks, block))\n        modules_body.append(ops.EoctConv(num_channels, num_channels, kernel_size))\n        '''\n        self.da = DAM_Module(64)\n        #self.da_conv = nn.Conv2d(64*3, 64, 3, 1, 1)\n        self.last_conv = nn.Conv2d(64*3, 64, kernel_size, 1, 1)\n        modules_tail1 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail1 = nn.Sequential(*modules_tail1)\n        '''\n        modules_tail2 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail2 = nn.Sequential(*modules_tail2)\n        \n        modules_tail3 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail3 = nn.Sequential(*modules_tail3)\n        '''\n              \n    def forward(self, x):\n        \n        x = self.sub_mean(x)\n        x = self.first_conv(x)\n        residual = x\n        #pdb.set_trace()\n\n        #stage1\n        x = self.stage1(x)\n        x = self.stage1_conv(x)\n        out1 = x[0].unsqueeze(1)\n        #pdb.set_trace()\n\n        #stage2\n        x = self.stage2(x)\n        x = self.stage2_conv(x)\n        out2 = x[0].unsqueeze(1)\n\n        #stage3\n        x = self.stage3(x)\n        out = self.stage3_conv(x).unsqueeze(1)\n        out2 = torch.cat([out1,out2], dim=1)\n        out = torch.cat([out2,out], dim=1)\n\n        out = self.da(out)\n        out = self.last_conv(out)\n\n        \n        #stage4\n        #x = self.stage4(x)\n        #x = self.stage4_conv(x)\n\n        out += residual\n\n        out = self.tail1(out)\n        out = self.add_mean(out)\n\n        #out2 = self.tail1(x[1])\n        #out2 = self.add_mean(out2)\n\n        #out3 = self.tail2(x[2])\n        #out3 = self.add_mean(out3)\n        #pdb.set_trace()\n        \n        return out\n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replace pre-trained upsampler to new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if 
len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))\n\nclass MatrixModelF2(nn.Module):\n    def __init__(self, args):\n        super(MatrixModelF2, self).__init__()\n        \n        num_channels = (64, 64, 64)\n        kernel_size = 3\n        reduction = args.reduction \n        scale = args.scale\n        block = blocks_dict[args.block]\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        self.add_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n        \n        self.first_conv = nn.Conv2d(3, 64, kernel_size, stride=1, padding=1, bias=True)\n        \n        modules_stage1 = []\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        self.stage1 = nn.Sequential(*modules_stage1)\n        self.stage1_conv = ops.EoctConv(64, (64,64), kernel_size)\n\n        modules_stage2 = []\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        self.stage2 = nn.Sequential(*modules_stage2)\n        self.stage2_conv = ops.EoctConv((64,64), num_channels, kernel_size)\n\n        modules_stage3 = []\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage3 = nn.Sequential(*modules_stage3)\n        self.stage3_conv = ops.EoctConv(num_channels, 64, kernel_size)\n        \n        '''\n        modules_stage4 = []\n        modules_stage4.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage4 = nn.Sequential(*modules_stage4)\n        self.stage4_conv = ops.EoctConv(num_channels, num_channels, kernel_size)\n\n        \n        modules_body = []\n        for i in range(n_groups):\n            modules_body.append(RERB(num_channels, num_channels, kernel_size, reduction, n_blocks, block))\n        modules_body.append(ops.EoctConv(num_channels, num_channels, kernel_size))\n        '''\n        inter_channels = 64 #// 4\n        self.conv5a = nn.Sequential(nn.Conv2d(64, inter_channels, 3, padding=1, bias=False),\n                                   nn.BatchNorm2d(inter_channels),\n                                   nn.ReLU())\n        \n        self.conv5c = nn.Sequential(nn.Conv2d(64, inter_channels, 3, padding=1, bias=False),\n                                   nn.BatchNorm2d(inter_channels),\n                                   nn.ReLU())\n\n        self.pa = PAM_Module(inter_channels)\n        self.ca = CAM_Module(inter_channels)\n        self.conv51 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),\n                                   nn.BatchNorm2d(inter_channels),\n                                   nn.ReLU())\n        self.conv52 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),\n                                   nn.BatchNorm2d(inter_channels),\n                                   nn.ReLU())\n\n        #self.conv6 = 
nn.Sequential(nn.Conv2d(192, 64, 1))\n        #self.da_conv = nn.Conv2d(64*3, 64, 3, 1, 1)\n        self.last_conv = nn.Conv2d(64*3, 64, kernel_size, 1, 1)\n        modules_tail1 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail1 = nn.Sequential(*modules_tail1)\n        '''\n        modules_tail2 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail2 = nn.Sequential(*modules_tail2)\n        \n        modules_tail3 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail3 = nn.Sequential(*modules_tail3)\n        '''\n              \n    def forward(self, x):\n        \n        x = self.sub_mean(x)\n        x = self.first_conv(x)\n        residual = x\n        #pdb.set_trace()\n\n        #stage1\n        x = self.stage1(x)\n        x = self.stage1_conv(x)\n        #pdb.set_trace()\n\n        #stage2\n        x = self.stage2(x)\n        x = self.stage2_conv(x)\n\n        #stage3\n        x = self.stage3(x)\n        out1 = self.stage3_conv(x)\n\n        feat1 = self.conv5a(out1)\n        pa_feat = self.pa(feat1)\n        pa_conv = self.conv51(pa_feat)\n        #pa_output = self.conv6(pa_conv)\n\n        feat2 = self.conv5c(out1)\n        ca_feat = self.ca(feat2)\n        ca_conv = self.conv52(ca_feat)\n        #ca_output = self.conv7(ca_conv)\n\n        feat_sum = torch.cat([pa_conv,ca_conv],dim=1)\n        paca_output = torch.cat([feat_sum,out1],dim=1)\n\n        out = self.last_conv(paca_output)\n\n        \n        #stage4\n        #x = self.stage4(x)\n        #x = self.stage4_conv(x)\n\n        out += residual\n\n        out = self.tail1(out)\n        out = self.add_mean(out)\n\n        #out2 = self.tail1(x[1])\n        #out2 = self.add_mean(out2)\n\n        #out3 = self.tail2(x[2])\n        #out3 = self.add_mean(out3)\n        #pdb.set_trace()\n        \n        return out\n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replace pre-trained upsampler to new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))\n\nclass PAM_Module(nn.Module):\n    \"\"\" Position 
attention module\"\"\"\n    #Ref from SAGAN\n    def __init__(self, in_dim):\n        super(PAM_Module, self).__init__()\n        self.chanel_in = in_dim\n\n        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)\n        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)\n        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)\n        self.gamma = nn.Parameter(torch.zeros(1))\n\n        self.softmax = nn.Softmax(dim=-1)\n    def forward(self, x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X (HxW) X (HxW)\n        \"\"\"\n        m_batchsize, C, height, width = x.size()\n        proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1)\n        proj_key = self.key_conv(x).view(m_batchsize, -1, width*height)\n        energy = torch.bmm(proj_query, proj_key)\n        attention = self.softmax(energy)\n        proj_value = self.value_conv(x).view(m_batchsize, -1, width*height)\n\n        out = torch.bmm(proj_value, attention.permute(0, 2, 1))\n        out = out.view(m_batchsize, C, height, width)\n\n        out = self.gamma*out + x\n        return out\n\n\nclass CAM_Module(nn.Module):\n    \"\"\" Channel attention module\"\"\"\n    def __init__(self, in_dim):\n        super(CAM_Module, self).__init__()\n        self.chanel_in = in_dim\n\n\n        self.gamma = nn.Parameter(torch.zeros(1))\n        self.softmax  = nn.Softmax(dim=-1)\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X C X C\n        \"\"\"\n        m_batchsize, C, height, width = x.size()\n        proj_query = x.view(m_batchsize, C, -1)\n        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)\n        energy = torch.bmm(proj_query, proj_key)\n        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        attention = self.softmax(energy_new)\n        proj_value = x.view(m_batchsize, C, -1)\n\n        out = torch.bmm(attention, proj_value)\n        out = out.view(m_batchsize, C, height, width)\n\n        out = self.gamma*out + x\n        return out\n\nclass GAM_Module(nn.Module):\n    \"\"\" Global attention module\"\"\"\n    def __init__(self, in_dim):\n        super(GAM_Module, self).__init__()\n        self.chanel_in = in_dim\n\n\n        self.gamma = nn.Parameter(torch.zeros(1))\n        self.softmax  = nn.Softmax(dim=-1)\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X (C*H*W) X (C*H*W)\n        \"\"\"\n        m_batchsize, C, height, width = x.size()\n        proj_query = x.view(m_batchsize, -1).unsqueeze(-1)\n        proj_key = x.view(m_batchsize, -1).unsqueeze(-1).permute(0, 2, 1)\n        #pdb.set_trace()\n        energy = torch.bmm(proj_query, proj_key)\n        #energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        attention = self.softmax(energy)\n        proj_value = x.view(m_batchsize, -1).unsqueeze(-1)\n\n        out = torch.bmm(attention, proj_value)\n        out = out.view(m_batchsize, C, height, 
width)\n\n        out = self.gamma*out + x\n        return out\n\nclass DAM_Module(nn.Module):\n    \"\"\" Deep attention module\"\"\"\n    def __init__(self, in_dim):\n        super(DAM_Module, self).__init__()\n        self.chanel_in = in_dim\n\n\n        self.gamma = nn.Parameter(torch.zeros(1))\n        self.softmax  = nn.Softmax(dim=-1)\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X N X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X N X N\n        \"\"\"\n        m_batchsize, N, C, height, width = x.size()\n        proj_query = x.view(m_batchsize, N, -1)\n        proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1)\n        energy = torch.bmm(proj_query, proj_key)\n        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        attention = self.softmax(energy_new)\n        proj_value = x.view(m_batchsize, N, -1)\n\n        out = torch.bmm(attention, proj_value)\n        out = out.view(m_batchsize, N, C, height, width)\n\n        out = self.gamma*out + x\n        out = out.view(m_batchsize, -1, height, width)\n        return out\n\nclass MatrixModelH(nn.Module):\n    def __init__(self, args):\n        super(MatrixModelH, self).__init__()\n        \n        num_channels = (64, 64, 64)\n        kernel_size = 3\n        reduction = args.reduction \n        scale = args.scale\n        block = blocks_dict[args.block]\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        self.add_mean = ops.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n        \n        self.first_conv = nn.Conv2d(3, 64, kernel_size, stride=1, padding=1, bias=True)\n        \n        modules_stage1 = []\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        modules_stage1.append(BFN1(64, kernel_size, reduction, 5, block))\n        self.stage1 = nn.Sequential(*modules_stage1)\n        self.stage1_conv = ops.EoctConv(64, (64,64), kernel_size)\n\n        modules_stage2 = []\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        modules_stage2.append(BFN2((64,64), kernel_size, reduction, 5, block))\n        self.stage2 = nn.Sequential(*modules_stage2)\n        self.stage2_conv = ops.EoctConv((64,64), num_channels, kernel_size)\n\n        modules_stage3 = []\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        modules_stage3.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage3 = nn.Sequential(*modules_stage3)\n        self.stage3_conv = ops.EoctConv(num_channels, 64, kernel_size)\n        \n        self.pa = PAM_Module(64)\n        self.pa_conv = nn.Conv2d(64, 64, 3, 1, 1)\n        self.ca = CAM_Module(64)\n        self.ca_conv = nn.Conv2d(64, 64, 3, 1, 1)\n        '''\n        modules_stage4 = []\n        modules_stage4.append(BFN(num_channels, kernel_size, reduction, 5, block))\n        self.stage4 = nn.Sequential(*modules_stage4)\n        self.stage4_conv = ops.EoctConv(num_channels, num_channels, kernel_size)\n\n        \n        modules_body = []\n 
       for i in range(n_groups):\n            modules_body.append(RERB(num_channels, num_channels, kernel_size, reduction, n_blocks, block))\n        modules_body.append(ops.EoctConv(num_channels, num_channels, kernel_size))\n        '''\n        \n        modules_tail1 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail1 = nn.Sequential(*modules_tail1)\n        '''\n        modules_tail2 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail2 = nn.Sequential(*modules_tail2)\n        \n        modules_tail3 = [\n            ops._UpsampleBlock(64, scale=scale),\n            nn.Conv2d(64, 3, kernel_size, 1, 1)]\n        \n        #self.body = nn.Sequential(*modules_body)\n        self.tail3 = nn.Sequential(*modules_tail3)\n        '''\n              \n    def forward(self, x):\n        \n        x = self.sub_mean(x)\n        x = self.first_conv(x)\n        residual = x\n        #pdb.set_trace()\n\n        #stage1\n        x = self.stage1(x)\n        x = self.stage1_conv(x)\n        #pdb.set_trace()\n\n        #stage2\n        x = self.stage2(x)\n        x = self.stage2_conv(x)\n\n        #stage3\n        x = self.stage3(x)\n        out = self.stage3_conv(x)\n        \n        #atten\n        pa_out = self.pa(out)\n        pa_out = self.pa_conv(pa_out)\n        ca_out = self.ca(out)\n        ca_out = self.ca_conv(ca_out)\n        out = pa_out + ca_out\n        #stage4\n        #x = self.stage4(x)\n        #x = self.stage4_conv(x)\n\n        out += residual\n\n        out = self.tail1(out)\n        out = self.add_mean(out)\n\n        #out2 = self.tail1(x[1])\n        #out2 = self.add_mean(out2)\n\n        #out3 = self.tail2(x[2])\n        #out3 = self.add_mean(out3)\n        #pdb.set_trace()\n        \n        return out\n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replace pre-trained upsampler to new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))"
  },
  {
    "path": "src/model/mdsr.py",
    "content": "from model import common\n\nimport torch.nn as nn\n\nurl = {\n    'r16f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr_baseline-a00cab12.pt',\n    'r80f64': 'https://cv.snu.ac.kr/research/EDSR/models/mdsr-4a78bedf.pt'\n}\n\ndef make_model(args, parent=False):\n    return MDSR(args)\n\nclass MDSR(nn.Module):\n    def __init__(self, args, conv=common.default_conv):\n        super(MDSR, self).__init__()\n        n_resblocks = args.n_resblocks\n        n_feats = args.n_feats\n        kernel_size = 3\n        act = nn.ReLU(True)\n        self.scale_idx = 0\n        self.url = url['r{}f{}'.format(n_resblocks, n_feats)]\n        self.sub_mean = common.MeanShift(args.rgb_range)\n        self.add_mean = common.MeanShift(args.rgb_range, sign=1)\n\n        m_head = [conv(args.n_colors, n_feats, kernel_size)]\n\n        self.pre_process = nn.ModuleList([\n            nn.Sequential(\n                common.ResBlock(conv, n_feats, 5, act=act),\n                common.ResBlock(conv, n_feats, 5, act=act)\n            ) for _ in args.scale\n        ])\n\n        m_body = [\n            common.ResBlock(\n                conv, n_feats, kernel_size, act=act\n            ) for _ in range(n_resblocks)\n        ]\n        m_body.append(conv(n_feats, n_feats, kernel_size))\n\n        self.upsample = nn.ModuleList([\n            common.Upsampler(conv, s, n_feats, act=False) for s in args.scale\n        ])\n\n        m_tail = [conv(n_feats, args.n_colors, kernel_size)]\n\n        self.head = nn.Sequential(*m_head)\n        self.body = nn.Sequential(*m_body)\n        self.tail = nn.Sequential(*m_tail)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        x = self.head(x)\n        x = self.pre_process[self.scale_idx](x)\n\n        res = self.body(x)\n        res += x\n\n        x = self.upsample[self.scale_idx](res)\n        x = self.tail(x)\n        x = self.add_mean(x)\n\n        return x\n\n    def set_scale(self, scale_idx):\n        self.scale_idx = scale_idx\n\n"
  },
  {
    "path": "src/model/ops.py",
    "content": "'''EoctConv'''\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\nimport torch\r\nimport numpy as np\r\nimport math\r\nimport pdb\r\n\r\nBN_MOMENTUM = 0.1\r\nclass EoctConv(nn.Module):\r\n    def __init__(self, in_channels, num_channels, kernel_size=3, stride=1, padding=1, bias=True, name=None):\r\n        super(EoctConv, self).__init__()\r\n        self.stride = stride\r\n        #input channels\r\n        if type(in_channels) is tuple and len(in_channels)==3:\r\n            in_h, in_l ,in_ll= in_channels\r\n        elif type(in_channels) is tuple and len(in_channels)==2:\r\n            in_h, in_l = in_channels\r\n            in_ll = None\r\n        else:\r\n            in_h, in_l ,in_ll= (in_channels, None, None)\r\n        #output channels\r\n        if type(num_channels) is tuple and len(num_channels)==3:\r\n            num_high, num_low, num_ll = num_channels\r\n        elif type(num_channels) is tuple and len(num_channels)==2:\r\n        #pdb.set_trace()\r\n            num_high, num_low = num_channels\r\n            num_ll = 0\r\n        else:\r\n            num_high, num_low, num_ll = (num_channels, 0, 0)\r\n        self.num_high = num_high\r\n        self.num_low = num_low\r\n        self.num_ll = num_ll\r\n        if in_h is not None:\r\n            self.conv2d1 = nn.Conv2d(in_h, num_high, kernel_size=3, stride=1, padding=1, bias=bias) if self.num_high > 0 else None\r\n            self.conv2d2 = nn.Conv2d(in_h, num_low, kernel_size=3, stride=1, padding=1, bias=bias) if self.num_low > 0 else None\r\n            self.conv2d3 = nn.Conv2d(in_h, num_ll, kernel_size=3, stride=1, padding=1, bias=bias) if self.num_ll > 0 else None\r\n        if in_l is not None:\r\n            self.conv2d4 = nn.Conv2d(in_l, num_low, kernel_size=3, stride=1, padding=1, bias=bias) if self.num_low > 0 else None\r\n            self.conv2d5 = nn.Conv2d(in_l, num_high, kernel_size=3, stride=1, padding=1, bias=bias) if self.num_high > 0 else None\r\n            self.conv2d6 = nn.Conv2d(in_l, num_ll, kernel_size=3, stride=1, padding=1, bias=bias) if self.num_ll > 0 else None\r\n        if in_ll is not None:\r\n            self.conv2d7 = nn.Conv2d(in_ll, num_ll, kernel_size=3, stride=1, padding=1, bias=bias) if self.num_ll > 0 else None\r\n            self.conv2d8 = nn.Conv2d(in_ll, num_high, kernel_size=3, stride=1, padding=1, bias=bias) if self.num_high > 0 else None\r\n            self.conv2d9 = nn.Conv2d(in_ll, num_low, kernel_size=3, stride=1, padding=1, bias=bias) if self.num_low > 0 else None\r\n        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')\r\n        self.upsample2 = nn.Upsample(scale_factor=4, mode='nearest')\r\n        self.pooling1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)\r\n        self.pooling2 = nn.AvgPool2d(kernel_size=4, stride=4, padding=0)\r\n        for m in self.modules():\r\n            if isinstance(m, nn.Conv2d):\r\n                n = m.kernel_size[0]*m.kernel_size[1]*m.out_channels\r\n                m.weight.data.normal_(0, math.sqrt(2. 
/ n))\r\n                if m.bias is not None:  # bias can be disabled via the bias argument\r\n                    nn.init.constant_(m.bias, 0)\r\n    def forward(self, data):\r\n        stride = self.stride\r\n        \r\n        #input channels\r\n        if type(data) is tuple and len(data)==3:\r\n            data_h, data_l, data_ll = data\r\n        elif type(data) is tuple and len(data)==2:\r\n            data_h, data_l = data\r\n            data_ll = None\r\n        else:\r\n            data_h, data_l, data_ll = (data, None, None)\r\n        data_h2l, data_h2h, data_h2ll, data_l2l, data_l2h, data_l2ll, data_ll2ll, data_ll2h, data_ll2l = None, None, None, None, None, None, None, None, None\r\n        \r\n        # processing high frequency group\r\n        if data_h is not None:\r\n            # High -> High\r\n            data_h = self.pooling1(data_h) if stride == 2 else data_h\r\n            data_h2h = self.conv2d1(data_h) if self.num_high > 0 else None\r\n            # High -> Low\r\n            data_h2l = self.pooling1(data_h) if (self.num_low > 0) else data_h\r\n            data_h2l = self.conv2d2(data_h2l) if self.num_low > 0 else None\r\n            # High -> Lower\r\n            data_h2ll = self.pooling2(data_h) if (self.num_ll > 0) else data_h\r\n            data_h2ll = self.conv2d3(data_h2ll) if self.num_ll > 0 else None\r\n        \r\n        # processing low frequency group\r\n        if data_l is not None:\r\n            # Low -> Low\r\n            data_l2l = self.pooling1(data_l) if (self.num_low > 0 and stride == 2) else data_l\r\n            data_l2l = self.conv2d4(data_l2l) if self.num_low > 0 else None\r\n            # Low -> High\r\n            data_l2h = self.conv2d5(data_l) if self.num_high > 0 else data_l\r\n            data_l2h = self.upsample1(data_l2h) if (self.num_high > 0 and stride == 1) else None\r\n            # Low -> Lower\r\n            data_l2ll = self.pooling1(data_l) if (self.num_ll > 0) else data_l\r\n            data_l2ll = self.conv2d6(data_l2ll) if self.num_ll > 0 else None\r\n    \r\n        # processing lower frequency group\r\n        if data_ll is not None:\r\n            # Lower -> Lower\r\n            data_ll2ll = self.pooling1(data_ll) if (self.num_ll > 0 and stride == 2) else data_ll\r\n            data_ll2ll = self.conv2d7(data_ll2ll) if self.num_ll > 0 else None\r\n            # Lower -> High\r\n            data_ll2h = self.conv2d8(data_ll) if self.num_high > 0 else data_ll\r\n            data_ll2h = self.upsample2(data_ll2h) if (self.num_high > 0 and stride == 1) else None\r\n            # Lower -> Low\r\n            data_ll2l = self.conv2d9(data_ll) if self.num_low > 0 else data_ll\r\n            data_ll2l = self.upsample1(data_ll2l) if (self.num_low > 0 and stride == 1) else None\r\n            \r\n        # you can force-disable the interaction paths here\r\n        # data_h2l = None if (data_h2h is not None) and (data_l2l is not None) else data_h2l\r\n        # data_l2h = None if (data_h2h is not None) and (data_l2l is not None) else data_l2h\r\n\r\n        # element-wise sum of every path feeding each resolution branch\r\n        output = (dataSum(dataSum(data_h2h, data_l2h), data_ll2h), dataSum(dataSum(data_h2l, data_l2l), data_ll2l), dataSum(dataSum(data_h2ll, data_l2ll), data_ll2ll))\r\n        # squeeze output (to be backward compatible)\r\n        if output[2] is None:\r\n            if output[1] is None:\r\n                return output[0]\r\n            else:\r\n                return output[0:2]\r\n        elif output[1] is None:\r\n            return output[0::2]\r\n        else:\r\n            return output\r\n        \r\ndef relu(data):\r\n    act = nn.ReLU(inplace=True)\r\n    if type(data) is tuple and len(data)==3:\r\n        return (act(data[0]), act(data[1]), act(data[2]))\r\n    elif type(data) is tuple and len(data)==2:\r\n        # a length-2 tuple never carries None entries (see the squeeze logic\r\n        # above); the old branches indexed data[2], which raised IndexError\r\n        return (act(data[0]), act(data[1]))\r\n    else:\r\n        return act(data)\r\n        \r\ndef sigmoid(data):\r\n    # torch.sigmoid replaces the deprecated F.sigmoid; the old fallback tested\r\n    # type(data) is Tensor with Tensor undefined, raising NameError\r\n    if type(data) is tuple and len(data)==3:\r\n        return (torch.sigmoid(data[0]), torch.sigmoid(data[1]), torch.sigmoid(data[2]))\r\n    elif type(data) is tuple and len(data)==2:\r\n        return (torch.sigmoid(data[0]), torch.sigmoid(data[1]))\r\n    else:\r\n        return torch.sigmoid(data)\r\n\r\ndef bn(data, num_channels):\r\n    # NOTE: these BatchNorm layers are re-created on every call, so they keep\r\n    # no running statistics and no learned affine parameters\r\n    if type(data) is tuple and len(data)==3:\r\n        bn1 = nn.BatchNorm2d(num_channels[0], momentum=BN_MOMENTUM)\r\n        bn2 = nn.BatchNorm2d(num_channels[1], momentum=BN_MOMENTUM)\r\n        bn3 = nn.BatchNorm2d(num_channels[2], momentum=BN_MOMENTUM)\r\n        return (bn1(data[0]), bn2(data[1]), bn3(data[2]))\r\n    elif type(data) is tuple and len(data)==2:\r\n        bn1 = nn.BatchNorm2d(num_channels[0], momentum=BN_MOMENTUM)\r\n        bn2 = nn.BatchNorm2d(num_channels[1], momentum=BN_MOMENTUM)\r\n        return (bn1(data[0]), bn2(data[1]))\r\n    else:\r\n        bn1 = nn.BatchNorm2d(num_channels, momentum=BN_MOMENTUM)\r\n        return bn1(data)\r\n    \r\ndef max_pool2d(data, l=(2,2)):\r\n    if type(data) is tuple and len(data)==3:\r\n        return (F.max_pool2d(data[0], l), F.max_pool2d(data[1], l), F.max_pool2d(data[2], l))\r\n    elif type(data) is tuple and len(data)==2:\r\n        return (F.max_pool2d(data[0], l), F.max_pool2d(data[1], l))\r\n    else:\r\n        return F.max_pool2d(data, l)\r\n        \r\ndef avg_pool2d(data):\r\n    avg_pool = nn.AdaptiveAvgPool2d(1)\r\n    if type(data) is tuple and len(data)==3:\r\n        return (avg_pool(data[0]), avg_pool(data[1]), avg_pool(data[2]))\r\n    elif type(data) is tuple and len(data)==2:\r\n        return (avg_pool(data[0]), avg_pool(data[1]))\r\n    else:\r\n        return avg_pool(data)\r\n\r\ndef dropout(data, l):\r\n    Dropout = nn.Dropout(l)\r\n    if type(data) is tuple and len(data)==3:\r\n        return (Dropout(data[0]), Dropout(data[1]), Dropout(data[2]))\r\n    elif type(data) is tuple and len(data)==2:\r\n        return (Dropout(data[0]), Dropout(data[1]))\r\n    else:\r\n        return Dropout(data)\r\n        \r\ndef dataSum(a, b):\r\n    if a is None:\r\n        return b\r\n    elif b is None:\r\n        return a\r\n    else:\r\n        assert a.size()==b.size()\r\n        return a+b\r\n\r\n\r\ndef tupleSum(a, b):\r\n    return (a[0]+b[0], a[1]+b[1], a[2]+b[2])\r\n        \r\nclass MeanShift(nn.Conv2d):\r\n    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):\r\n        super(MeanShift, self).__init__(3, 3, kernel_size=1)\r\n        std = torch.Tensor(rgb_std)\r\n        self.weight.data = torch.eye(3).view(3, 3, 1, 1)\r\n        self.weight.data.div_(std.view(3, 1, 1, 1))\r\n        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)\r\n        self.bias.data.div_(std)\r\n        # freeze the shift; setting self.requires_grad on the module itself is\r\n        # a no-op, the flag must be set on each parameter\r\n        for p in self.parameters():\r\n            p.requires_grad = False\r\n        \r\nclass _UpsampleBlock(nn.Module):\r\n    def __init__(self, \r\n                 n_channels, scale, \r\n                 group=1):\r\n        super(_UpsampleBlock, self).__init__()\r\n        '''\r\n        modules = []\r\n        if scale == 2 or scale == 4 or scale == 8:\r\n            for _ in range(int(math.log(scale, 2))):\r\n                modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]\r\n                modules += [nn.PixelShuffle(2)]\r\n        elif scale == 3:\r\n            modules += [nn.Conv2d(n_channels, 9*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)]\r\n            modules += [nn.PixelShuffle(3)]\r\n\r\n        self.body = nn.Sequential(*modules)'''\r\n        #init_weights(self.modules)\r\n        # NOTE: the active implementation hard-codes two x2 PixelShuffle stages,\r\n        # i.e. a fixed x4 upsampler; the scale argument is currently ignored\r\n        self.conv1 = nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group)\r\n        self.conv2 = nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group)\r\n        self.relu = nn.ReLU(inplace=True)\r\n        self.pixelshuffle = nn.PixelShuffle(2)\r\n        \r\n    def forward(self, x):\r\n        out = self.conv1(x)\r\n        out = self.relu(out)\r\n        out = self.pixelshuffle(out)\r\n\r\n        out = self.conv2(out)\r\n        out = self.relu(out)\r\n        out = self.pixelshuffle(out)\r\n\r\n        return out\r\n\r\ndef tupleMultiply(a, b):\r\n    out = []\r\n    assert type(b) is int\r\n    for i in range(len(a)):\r\n        out.append(a[i]*b)\r\n\r\n    return tuple(out)\r\n"
  },
  {
    "path": "src/model/rcan.py",
    "content": "from model import common\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport pdb\n\ndef make_model(args, parent=False):\n    return RCAN(args)\n\n## Channel Attention (CA) Layer\nclass CALayer(nn.Module):\n    def __init__(self, channel, reduction=16):\n        super(CALayer, self).__init__()\n        # global average pooling: feature --> point\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        # feature channel downscale and upscale --> channel weight\n        self.conv_du = nn.Sequential(\n                nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),\n                nn.ReLU(inplace=True),\n                nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),\n                nn.Sigmoid()\n        )\n\n    def forward(self, x):\n        y = self.avg_pool(x)\n        y = self.conv_du(y)\n        return x * y\n\n## Residual Channel Attention Block (RCAB)\n\nclass Ada_conv(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, bias=True, category=2):\n        super(Ada_conv, self).__init__()\n        self.conv0 = nn.Conv2d(\n            in_channels, out_channels, 1,\n            padding=0, bias=bias)\n        self.sigmoid = nn.Sigmoid()\n        self.category = category\n        self.conv1 = nn.Conv2d(\n            in_channels, out_channels, kernel_size,\n            padding=(kernel_size//2), bias=bias)\n        self.conv2 = nn.Conv2d(\n            in_channels, out_channels, kernel_size,\n            padding=(kernel_size//2), bias=bias)\n    def forward(self, x):\n        # c = list(np.arange(0,1,1/self.category))\n        # c += 1\n        m_batchsize, C, height, width = x.size()\n        mask = self.sigmoid(self.conv0(x.permute(0,1,3,2).contiguous().view(m_batchsize,C,height,width)))\n        #mask = self.sigmoid(self.conv0(x))\n        mask = torch.where(mask<0.5, torch.full_like(mask,1),torch.full_like(mask,0))\n        #pdb.set_trace()\n        out = self.conv1(x)*mask+self.conv2(x)*(1-mask)\n        return out\n\nclass ResAda_conv(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, bias=True, category=2):\n        super(ResAda_conv, self).__init__()\n        self.conv0 = nn.Conv2d(\n            in_channels, 1, 1,\n            padding=0, bias=bias)\n        self.sigmoid = nn.Sigmoid()\n        self.category = category\n        self.conv1 = nn.Conv2d(\n            in_channels, out_channels, kernel_size,\n            padding=(kernel_size//2), bias=bias)\n        self.conv2 = nn.Conv2d(\n            in_channels, out_channels, kernel_size,\n            padding=(kernel_size//2), bias=bias)\n    def forward(self, x):\n        # c = list(np.arange(0,1,1/self.category))\n        # c += 1\n        m_batchsize, C, height, width = x.size()\n        mask = self.sigmoid(self.conv0(x))\n        mask = torch.where(mask<0.5, torch.full_like(mask,1),torch.full_like(mask,0))\n        #pdb.set_trace()\n        #mask = mask[mask<0.5].view(m_batchsize,C,height,width)\n        out = self.conv1(x)*mask+self.conv2(x)*(1-mask)\n        return out+x\n\nclass RCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction,\n        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(RCAB, self).__init__()\n        modules_body = []\n        for i in range(2):\n            #modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))\n            modules_body.append(Ada_conv(n_feat, n_feat, kernel_size, bias=bias))\n            if bn: 
modules_body.append(nn.BatchNorm2d(n_feat))\n            if i == 0: modules_body.append(act)\n        modules_body.append(CALayer(n_feat, reduction))\n        self.body = nn.Sequential(*modules_body)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x)\n        #res = self.body(x).mul(self.res_scale)\n        res += x\n        return res\n\n## Residual Group (RG)\nclass ResidualGroup(nn.Module):\n    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):\n        super(ResidualGroup, self).__init__()\n        modules_body = []\n        modules_body = [\n            RCAB(\n                conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=act, res_scale=1) \\\n            for _ in range(n_resblocks)]\n        modules_body.append(conv(n_feat, n_feat, kernel_size))\n        #modules_body.append(Ada_conv(n_feat, n_feat, kernel_size))\n        self.body = nn.Sequential(*modules_body)\n\n    def forward(self, x):\n        res = self.body(x)\n        res += x\n        return res\n\n## Residual Channel Attention Network (RCAN)\nclass RCAN(nn.Module):\n    def __init__(self, args, conv=common.default_conv):\n        super(RCAN, self).__init__()\n        \n        n_resgroups = args.n_resgroups\n        n_resblocks = args.n_resblocks\n        n_feats = args.n_feats\n        kernel_size = 3\n        reduction = args.reduction \n        scale = args.scale[0]\n        act = nn.ReLU(True)\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        \n        # define head module\n        modules_head = [conv(args.n_colors, n_feats, kernel_size)]\n\n        # define body module\n        modules_body = [\n            ResidualGroup(\n                conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \\\n            for _ in range(n_resgroups)]\n\n        modules_body.append(conv(n_feats, n_feats, kernel_size))\n\n        # define tail module\n        modules_tail = [\n            common.Upsampler(conv, scale, n_feats, act=False),\n            conv(n_feats, args.n_colors, kernel_size)]\n\n        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n        self.tail = nn.Sequential(*modules_tail)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        x = self.head(x)\n\n        res = self.body(x)\n        res += x\n\n        x = self.tail(res)\n        x = self.add_mean(x)\n\n        return x \n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replace pre-trained upsampler to new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                             
              .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))"
  },
  {
    "path": "src/model/rcan1.py",
    "content": "from model import common\n\nimport torch.nn as nn\nimport torch\nimport torch.nn.init as init\nimport pdb\n\ndef make_model(args, parent=False):\n    return RCAN(args)\n\n## Channel Attention (CA) Layer\nclass CALayer(nn.Module):\n    def __init__(self, channel, reduction=16):\n        super(CALayer, self).__init__()\n        # global average pooling: feature --> point\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        # feature channel downscale and upscale --> channel weight\n        self.conv_du = nn.Sequential(\n                nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),\n                nn.ReLU(inplace=True),\n                nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),\n                nn.Sigmoid()\n        )\n\n    def forward(self, x):\n        y = self.avg_pool(x)\n        y = self.conv_du(y)\n        return x * y\n\nclass Dis(nn.Module):\n    def __init__(self, loss_type='L1', batchsize=16):\n        super(Dis, self).__init__()\n        self.loss_type = loss_type\n        if self.loss_type == 'cos':\n            self.dot_product, self.square_sum_x, self.square_sum_y = torch.zeros(batchsize).cuda(), torch.zeros(batchsize).cuda(), torch.zeros(batchsize).cuda()\n\n\n    def forward(self, x1, x2):\n\n        if self.loss_type=='L1':\n            return self.L1Loss(x1, x2)\n\n        if self.loss_type=='L2':\n            return self.L2Loss(x1, x2)\n\n        if self.loss_type=='cos':\n            return self.cosine_similarity(x1, x2)\n\n\n    def L1Loss(self, x1, x2):\n\n        loss = torch.sum(torch.abs(x1[:]-x2[:]), dim=1)\n        return loss\n\n    def L2Loss(self, x1, x2):\n\n        loss = torch.sum((x1[:]-x2[:]).pow(2), dim=1)\n        return loss\n\n    def bit_product_sum(self, x, y):\n        return sum([item[0] * item[1] for item in zip(x, y)])\n\n\n    def cosine_similarity(self, x, y, norm=True):\n        \"\"\" Compute the cosine similarity of the two vectors x and y \"\"\"\n        assert len(x) == len(y), \"len(x) != len(y)\"\n\n        # method 1\n        #res = torch.tensor([[x[i] * y[i], x[i] * x[i], y[i] * y[i]] for i in range(len(x))])\n        #cos = sum(res[:, 0]) / (torch.sqrt(sum(res[:, 1])) * torch.sqrt(sum(res[:, 2])))\n\n        # method 2\n        # cos = self.bit_product_sum(x, y) / (torch.sqrt(self.bit_product_sum(x, x)) * torch.sqrt(self.bit_product_sum(y, y)))\n\n        #method 3\n        dot_product, square_sum_x, square_sum_y = self.dot_product, self.square_sum_x, self.square_sum_y\n        for i in range(x.size()[1]):\n            dot_product += x[:,i] * y[:,i]\n            square_sum_x += x[:,i] * x[:,i]\n            square_sum_y += y[:,i] * y[:,i]\n        cos = dot_product / (torch.sqrt(square_sum_x) * torch.sqrt(square_sum_y))\n\n        return 0.5 * cos + 0  # maps cos to [-0.5, 0.5]; rcan3.py uses 0.5*cos + 0.5 instead\n\nclass FullConvRes(nn.Module):\n    \"\"\" Full Receptive Field Conv2d Residual Block\"\"\"\n    def __init__(self, out_channels=64, in_channels=64, K=9):\n        super(FullConvRes, self).__init__()\n\n        #self.dis = Dis('cos', batchsize=1)\n        self.out_channels = out_channels\n        self.K = K\n        #self.conv = nn.Conv2d(K,K,1,1,0)\n        #self.sigmoid = nn.Sigmoid()\n        self.gamma = nn.Parameter(torch.zeros(1))\n        #self.softmax  = nn.Softmax(dim=-1)\n        self.weight = nn.Parameter(\n            torch.zeros(out_channels, in_channels, K)\n        )\n        self.bias = nn.Parameter(torch.zeros(out_channels))\n        init.xavier_uniform_(self.weight)\n        init.constant_(self.bias, 0.1)\n        self.relu = nn.ReLU(True)\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X C X H X W)\n            returns :\n                out : fullconv value + input feature\n                attention: B X HW X 9\n            process:\n            reshape x > 2d\n            \n        \"\"\"\n        m_batchsize, C, height, width = x.size()\n        proj_query = x.view(m_batchsize, C, -1)\n        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)\n        energy = torch.bmm(proj_key, proj_query)\n        energy1 = torch.zeros((m_batchsize, height*width, 1)).cuda()\n        for i in range(height*width):\n            energy1.data[:,i] = torch.sqrt(energy[:,i,i]).unsqueeze(1)\n        energy2 = energy1.permute(0, 2, 1)\n        energy_new = energy/energy1.expand_as(energy)\n        energy_new = energy_new/energy2.expand_as(energy)\n        #energy_new = energy_new*energy\n        #energy = self.softmax(energy)\n        #energy_new = self.softmax(energy_new)\n\n        #energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        energy_new = torch.sort(energy_new, dim=-1)[1].float()\n        e = torch.chunk(energy_new, self.K, dim=-1)\n        for i in range(self.K):\n            if i == 0:\n                energy_new = e[i][:,:,0].unsqueeze(2)\n            else:\n                energy_new = torch.cat([energy_new, e[i][:,:,0].unsqueeze(2)],dim=2)\n        #energy_new = torch.stack(torch.chunk(energy_new, self.K, dim=-1)[:][:,:,0],dim=-1)\n        energy_new = energy_new.long()\n\n        ReceptiveField = torch.zeros_like(energy)\n        for b in range(m_batchsize):\n            for t in range(height*width):\n                for k in range(self.K):\n                    ReceptiveField.data[b,t,energy_new[b,t,k]] = 1\n        # max5 = max((1,9))  # take top-1; use max((1,9)) to take both top-1 and top-9\n        # max4 = max((1,4))\n        # _, ReceptiveFieldIdex1 = energy_new.topk(max5, -1, True, False)\n        # _, ReceptiveFieldIdex2 = (-1*energy_new).topk(max4, -1, True, False)\n        # ReceptiveFieldIdex = torch.cat([ReceptiveFieldIdex1,ReceptiveFieldIdex2], 2)\n        # ReceptiveFieldIdex = (self.sigmoid(self.conv(ReceptiveFieldIdex.unsqueeze(3)).squeeze(3))*height*width).int()\n\n        #x_in = x.view(m_batchsize,-1,height*width)\n        out = torch.zeros_like(proj_query).cuda()\n\n        # NOTE: Python-level loops over channels and positions; extremely slow,\n        # experimental implementation\n        for i in range(self.out_channels):\n            for j in range(height*width):\n                x_out = proj_query[ReceptiveField[:,j].unsqueeze(1).expand_as(proj_query)>0].view(m_batchsize,C,-1)\n                #x_out,_ = x_in.topk(max9, -1, True, False) # The shape of x_in:B X C X 9\n                x_K = torch.sum(x_out*(self.weight[i].expand_as(x_out)), dim=1)\n                out.data[:,i,j] = torch.sum(x_K, dim=1)+self.bias[i]\n        out = self.relu(out.view(m_batchsize,C,height,width))\n\n        # max9 = max((1,9))\n        # for i in range(self.out_channels):\n        #     for j in range(height*width):\n        #         x1 = x_in*(ReceptiveField[:,j].unsqueeze(1).expand_as(x_in))\n        #         x_out,_ = x1.topk(max9, -1, True, False) # The shape of x_out:B X C X 9\n        #         x_9 = torch.sum(x_out*(self.weight[i].expand_as(x_out)), dim=1)\n        #         out.data[:,i,j] = torch.sum(x_9, dim=1)+self.bias[i]\n        # out = self.relu(out.view(m_batchsize,C,height,width))\n\n        return self.gamma * out + x\n\nclass FullConvRes1(nn.Module):\n    \"\"\" Full Receptive Field Conv2d Residual Block\"\"\"\n    def __init__(self, out_channels=64, in_channels=64, kernel_size=3):\n        super(FullConvRes1, self).__init__()\n\n        #self.dis = Dis('cos', batchsize=1)\n        self.out_channels = out_channels\n        self.query_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1)\n        self.key_conv = nn.Conv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1)\n        #self.value_conv = Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1)\n        self.gamma = nn.Parameter(torch.zeros(1))\n        #self.softmax  = nn.Softmax(dim=-1)\n        self.weight = nn.Parameter(\n            torch.Tensor(out_channels, in_channels, kernel_size*kernel_size)\n        )\n        self.bias = nn.Parameter(torch.zeros(out_channels))\n        init.xavier_uniform_(self.weight)\n        init.constant_(self.bias, 0.1)\n        self.relu = nn.ReLU(True)\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X C X H X W)\n            returns :\n                out : fullconv value + input feature\n                attention: B X HW X 9\n            process:\n            reshape x > 2d\n            \n        \"\"\"\n        m_batchsize, C, height, width = x.size()\n        proj_query = self.query_conv(x).view(m_batchsize, C//8, -1)\n        proj_key = self.key_conv(x).view(m_batchsize, C//8, -1).permute(0, 2, 1)\n        energy = torch.bmm(proj_key, proj_query)\n        # energy1 = torch.zeros((m_batchsize, height*width, 1)).cuda()\n        # for i in range(height*width):\n        #     energy1.data[:,i] = torch.sqrt(energy[:,i,i]).unsqueeze(1)\n        # energy2 = energy1.permute(0, 2, 1)\n        # energy_new = energy/energy1.expand_as(energy)\n        # energy_new = energy_new/energy2.expand_as(energy)\n        #energy = self.softmax(energy)\n\n        #energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        max9 = max((1,9))  # take top-1; use max((1,9)) to take both top-1 and top-9\n        _, ReceptiveFieldIdex = energy.topk(max9, -1, True, False)\n\n        proj_query = x.view(m_batchsize,-1,height*width)\n        out = torch.zeros_like(proj_query).cuda()\n\n        ReceptiveField = torch.zeros_like(energy)\n        for b in range(m_batchsize):\n            for t in range(height*width):\n                for k in range(9):\n                    ReceptiveField.data[b,t,ReceptiveFieldIdex[b,t,k]] = 1\n\n        for i in range(self.out_channels):\n            for j in range(height*width):\n                x_out = proj_query[ReceptiveField[:,j].unsqueeze(1).expand_as(proj_query)>0].view(m_batchsize,C,-1)\n                #x_out,_ = x_in.topk(max9, -1, True, False) # The shape of x_in:B X C X 9\n                x_K = torch.sum(x_out*(self.weight[i].expand_as(x_out)), dim=1)\n                out.data[:,i,j] = torch.sum(x_K, dim=1)+self.bias[i]\n        out = self.relu(out.view(m_batchsize,C,height,width))\n        # for i in range(self.out_channels):\n        #     for j in range(height*width):\n        #         x_in = x_in*(ReceptiveField[:,j].unsqueeze(1).expand_as(x_in))\n        #         x_out,_ = x_in.topk(max9, -1, True, False) # The shape of x_in:B X C X 9\n        #         x_9 = torch.sum(x_out*(self.weight[i].expand_as(x_out)), dim=1)\n        #         out.data[:,i,j] = torch.sum(x_9, dim=1)\n        # out = self.relu(out.view(m_batchsize,C,height,width))\n\n        return self.gamma * out + x\n\nclass FullConv(nn.Module):\n    \"\"\" Full Receptive Field Conv2d Block\"\"\"\n    def __init__(self, out_channels=64, in_channels=64, kernel_size=3):\n        super(FullConv, self).__init__()\n\n        self.dis = Dis('cos', batchsize=16)\n        self.out_channels = out_channels\n        #self.gamma = nn.Parameter(torch.zeros(1))\n        self.softmax  = nn.Softmax(dim=-1)\n        # NOTE: torch.Tensor(...) allocates uninitialized memory and no init\n        # call follows, so this weight starts from arbitrary values\n        self.weight = nn.Parameter(\n            torch.Tensor(out_channels, in_channels, kernel_size*kernel_size)\n        )\n        self.relu = nn.ReLU(True)\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X C X H X W)\n            returns :\n                out : fullconv value + input feature\n                attention: B X HW X 9\n            process:\n            reshape x > 2d\n            \n        \"\"\"\n        m_batchsize, C, height, width = x.size()\n        energy1 = torch.zeros((m_batchsize, height*width, height*width)).cuda()\n        proj_query = x.view(m_batchsize, C, -1).permute(0, 2, 1)\n        proj_key = x.view(m_batchsize, C, -1)\n        energy2 = torch.bmm(proj_query, proj_key)\n        for i in range(height*width):\n            for j in range(i,height*width):\n                energy1.data[:,i,j] = self.dis(proj_query[:,i],proj_query[:,j])\n                energy1.data[:,j,i] = energy1.data[:,i,j]\n        \n        energy = energy1*energy2\n\n        #energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        maxk = max((1,9))  # take top-1; use max((1,9)) to take both top-1 and top-9\n        top9, ReceptiveField = energy.topk(maxk, -1, True, False)\n        top9 = top9*ReceptiveField\n        score = self.softmax(top9)\n        x_in = x.view(m_batchsize,-1,height*width)\n        out = x_in\n        for i in range(self.out_channels):\n            for j in range(height*width):\n                x_in = x_in[:,:,ReceptiveField[:,j,:]]*score[:,j,:] # The shape of x:B X C X 9\n                out[:,i,j] = torch.sum(x_in*self.weight[i].expand_as(x), dim=0)\n        out = self.relu(out.view(m_batchsize,C,height,width))\n\n        return out\n\n## Residual Channel Attention Block (RCAB)\nclass RCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction,\n        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(RCAB, self).__init__()\n        modules_body = []\n        for i in range(2):\n            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))\n            if bn: modules_body.append(nn.BatchNorm2d(n_feat))\n            if i == 0: modules_body.append(act)\n        modules_body.append(CALayer(n_feat, reduction))\n        self.body = nn.Sequential(*modules_body)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x)\n        #res = self.body(x).mul(self.res_scale)\n        res += x\n        return res\n\n## Residual Group (RG)\nclass ResidualGroup(nn.Module):\n    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):\n        super(ResidualGroup, self).__init__()\n        modules_body = []\n        modules_body = [\n            RCAB(\n                conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \\\n            for _ in range(n_resblocks)]\n        modules_body.append(conv(n_feat, n_feat, kernel_size))\n        #modules_body.append(FullConvRes(n_feat, n_feat, kernel_size))\n        self.body = nn.Sequential(*modules_body)\n\n    def forward(self, x):\n        res = self.body(x)\n        res += x\n        return res\n\n## Residual Channel Attention Network (RCAN)\nclass RCAN(nn.Module):\n    def __init__(self, args, conv=common.default_conv):\n        super(RCAN, self).__init__()\n        \n        n_resgroups = args.n_resgroups\n        n_resblocks = args.n_resblocks\n        n_feats = args.n_feats\n        kernel_size = 3\n        reduction = args.reduction\n        scale = args.scale[0]\n        act = nn.ReLU(True)\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        \n        # define head module\n        modules_head = [conv(args.n_colors, n_feats, kernel_size)]\n\n        # define body module\n        modules_body = [\n            ResidualGroup(\n                conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \\\n            for _ in range(n_resgroups)]\n\n        modules_body.append(conv(n_feats, n_feats, kernel_size))\n        modules_body.append(FullConvRes1())\n\n        # define tail module\n        modules_tail = [\n            common.Upsampler(conv, scale, n_feats, act=False),\n            conv(n_feats, args.n_colors, kernel_size)]\n\n        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n        #self.fcr = FullConvRes()\n        self.tail = nn.Sequential(*modules_tail)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        x = self.head(x)\n\n        res = self.body(x)\n        res += x\n\n        x = self.tail(res)\n        x = self.add_mean(x)\n\n        return x\n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replacing the pre-trained upsampler with a new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))"
  },
  {
    "path": "src/model/rcan3.py",
    "content": "from model import common\nimport torch\nimport torch.nn as nn\nfrom torch.autograd import Variable\nimport pdb\nimport numpy as np\nimport math\n\ndef make_model(args, parent=False):\n    return RCAN(args)\n\n## Channel Attention (CA) Layer\nclass CALayer(nn.Module):\n    def __init__(self, channel, reduction=16):\n        super(CALayer, self).__init__()\n        # global average pooling: feature --> point\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        # feature channel downscale and upscale --> channel weight\n        self.conv_du = nn.Sequential(\n                nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),\n                nn.ReLU(inplace=True),\n                nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),\n                nn.Sigmoid()\n        )\n\n    def forward(self, x):\n        y = self.avg_pool(x)\n        y = self.conv_du(y)\n        return x * y\n\nclass MSCALayer(nn.Module):\n    # unused stub; call the nn.Module initializer so instantiation is safe\n    def __init__(self):\n        super(MSCALayer, self).__init__()\n\nclass Dis(nn.Module):\n    def __init__(self, loss_type='L1', B=4):\n        super(Dis, self).__init__()\n        self.loss_type = loss_type\n        if self.loss_type == 'cos':\n            self.dot_product, self.square_sum_x, self.square_sum_y = torch.zeros(B), torch.zeros(B), torch.zeros(B)\n\n\n    def forward(self, x1, x2):\n\n        if self.loss_type=='L1':\n            return self.L1Loss(x1, x2)\n\n        if self.loss_type=='L2':\n            return self.L2Loss(x1, x2)\n\n        if self.loss_type=='cos':\n            return self.cosine_similarity(x1, x2)\n\n\n    def L1Loss(self, x1, x2):\n\n        loss = torch.sum(torch.abs(x1[:]-x2[:]), dim=1)\n        return loss\n\n    def L2Loss(self, x1, x2):\n\n        loss = torch.sum((x1[:]-x2[:]).pow(2), dim=1)\n        return loss\n\n    def bit_product_sum(self, x, y):\n        return sum([item[0] * item[1] for item in zip(x, y)])\n\n\n    def cosine_similarity(self, x, y, norm=True):\n        \"\"\" Compute the cosine similarity of the two vectors x and y \"\"\"\n        assert len(x) == len(y), \"len(x) != len(y)\"\n\n        # method 1\n        #res = torch.tensor([[x[i] * y[i], x[i] * x[i], y[i] * y[i]] for i in range(len(x))])\n        #cos = sum(res[:, 0]) / (torch.sqrt(sum(res[:, 1])) * torch.sqrt(sum(res[:, 2])))\n\n        # method 2\n        # cos = self.bit_product_sum(x, y) / (torch.sqrt(self.bit_product_sum(x, x)) * torch.sqrt(self.bit_product_sum(y, y)))\n\n        #method 3\n        dot_product, square_sum_x, square_sum_y = self.dot_product, self.square_sum_x, self.square_sum_y\n        for i in range(x.size()[1]):\n            dot_product[:] += x[:,i] * y[:,i]\n            square_sum_x[:] += x[:,i] * x[:,i]\n            square_sum_y[:] += y[:,i] * y[:,i]\n        cos = dot_product / (torch.sqrt(square_sum_x) * torch.sqrt(square_sum_y))\n\n        return 0.5 * cos + 0.5 if norm else cos  # normalize to the [0, 1] range\n\nclass DAM_Module(nn.Module):\n    \"\"\" Deep attention module\"\"\"\n    def __init__(self, in_dim):\n        super(DAM_Module, self).__init__()\n        self.chanel_in = in_dim\n\n\n        self.gamma = nn.Parameter(torch.zeros(1))\n        self.softmax  = nn.Softmax(dim=-1)\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X N X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X N X N\n        \"\"\"\n        m_batchsize, N, C, height, width = x.size()\n        proj_query = x.view(m_batchsize, N, 
        proj_query = x.view(m_batchsize, N, -1)\n        proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1)\n        energy = torch.bmm(proj_query, proj_key)\n        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        attention = self.softmax(energy_new)\n        proj_value = x.view(m_batchsize, N, -1)\n\n        out = torch.bmm(attention, proj_value)\n        out = out.view(m_batchsize, N, C, height, width)\n\n        out = self.gamma*out + x\n        out = out.view(m_batchsize, -1, height, width)\n        return out\n\nclass SEDAM_Module(nn.Module):\n    \"\"\" SE deep attention module\"\"\"\n    def __init__(self, in_dim):\n        super(SEDAM_Module, self).__init__()\n        self.channel_in = in_dim\n\n        # conv_du assumes N == 11 stacked features (11 * 11 == 121)\n        self.conv_du = nn.Sequential(\n            nn.Conv2d(121, 11, 1, padding=0, bias=True),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(11, 121, 1, padding=0, bias=True),\n        )\n        self.gamma = nn.Parameter(torch.zeros(1))\n        self.softmax  = nn.Softmax(dim=-1)\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X N X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X N X N\n        \"\"\"\n        m_batchsize, N, C, height, width = x.size()\n\n        # proj_query = x.view(m_batchsize, N, -1)\n        # proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1)\n        # energy = torch.bmm(proj_query, proj_key)\n        # energy = self.conv_du(energy.view(m_batchsize, -1, 1, 1)).view(m_batchsize, N, N)\n        proj_query = x.view(m_batchsize, N, -1)\n        proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1)\n        energy = torch.bmm(proj_query, proj_key)\n        energy1 = torch.zeros((m_batchsize, N, 1), device=x.device)\n        for i in range(N):\n            energy1.data[:,i] = torch.sqrt(energy[:,i,i]).unsqueeze(1)\n        energy2 = energy1.permute(0, 2, 1)\n        energy = energy/energy1.expand_as(energy)\n        energy = energy/energy2.expand_as(energy)\n        energy = self.conv_du(energy.view(m_batchsize, -1, 1, 1)).view(m_batchsize, N, N)\n        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n\n        attention = self.softmax(energy_new)\n        proj_value = x.view(m_batchsize, N, -1)\n\n        out = torch.bmm(attention, proj_value)\n        out = out.view(m_batchsize, N, C, height, width)\n\n        out = self.gamma*out + x\n        out = out.view(m_batchsize, -1, height, width)\n        return out\n
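\n# Added sketch (not in the original file): the diagonal-normalisation loops in\n# SEDAM_Module.forward and its twins below write through .data, which bypasses\n# autograd; the same cosine normalisation of a (B, N, N) Gram matrix can be done\n# in one differentiable step. Illustrative helper, not used by the modules here:\ndef normalize_gram(energy, eps=1e-8):\n    norms = torch.diagonal(energy, dim1=1, dim2=2).clamp(min=eps).sqrt()  # (B, N)\n    return energy / (norms.unsqueeze(2) * norms.unsqueeze(1))\n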
\nclass MSAM_Module(nn.Module):\n    \"\"\"MultiScale Spatial Attention\"\"\"\n    def __init__(self, in_dim):\n        super(MSAM_Module, self).__init__()\n        self.channel_in = in_dim\n\n        # self.conv_du = nn.Sequential(\n        #     nn.Conv2d(2304*2304, 48, 1, padding=0, bias=True),\n        #     nn.ReLU(inplace=True),\n        #     nn.Conv2d(48, 2304*2304, 1, padding=0, bias=True),\n        # )\n        self.conv0 = nn.Conv2d(in_dim, in_dim//16, 1, 1, 0)\n        #self.conv1 = nn.Conv2d(in_dim/2, in_dim/2, 3, 1, 1)\n        self.conv = nn.Conv2d(in_dim//16, in_dim//16, 3, 1, 1)\n        self.atten_conv = nn.Conv2d(in_dim//16,1,1,1,0)\n\n        self.last_conv = nn.Conv2d(in_dim//16*4, in_dim, 1, 1, 0)\n\n        self.gamma = nn.Parameter(torch.zeros(1))\n        self.softmax  = nn.Softmax(dim=-1)\n        self.sigmoid = nn.Sigmoid()\n        self.relu = nn.ReLU(True)\n\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X HW X HW\n        \"\"\"\n        m_batchsize, C, height, width = x.size()\n        x1 = self.multi_scale(x)\n\n        proj_query = x1.view(m_batchsize, -1, C*height*width//16)\n        proj_key = x1.view(m_batchsize, -1, C*height*width//16).permute(0, 2, 1)\n        energy = torch.bmm(proj_query, proj_key)\n        #energy = self.conv_du(energy.view(m_batchsize, -1, 1, 1)).view(m_batchsize, H*W, H*W)\n        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n\n        attention = self.softmax(energy_new)\n        proj_value = x1.view(m_batchsize, -1, C*height*width//16)\n\n        out = torch.bmm(attention, proj_value)\n        out = out.view(m_batchsize, -1, height, width)\n        out = self.last_conv(out)\n\n        out = self.gamma*out + x\n        #out = out.view(m_batchsize, -1, height, width)\n        return out\n\n    def attention(self, x):\n        out = self.sigmoid(self.atten_conv(x))\n\n        return out*x+x\n\n    def one_scale(self, x, scale=2):\n        m_batchsize, C, height, width = x.size()\n        downsample = nn.AvgPool2d(scale, stride=scale)\n        upsample = nn.Upsample(scale_factor=scale, mode='nearest')\n        #pdb.set_trace()\n        x = downsample(x)\n        x = self.relu(self.conv(x))\n        x = upsample(x)\n        x = self.attention(x)\n\n        return x\n\n    def multi_scale(self, x):\n        x = self.relu(self.conv0(x))\n        out = self.conv(x)\n        out = out.unsqueeze(1)\n        scale_list = [2,3,4]\n        for scale in scale_list:\n            x1 = self.one_scale(x, scale).unsqueeze(1)\n            #pdb.set_trace()\n            out = torch.cat([out, x1], 1)\n\n        return out\n\nclass SAM_Module(nn.Module):\n    \"\"\"SE Spatial Attention\"\"\"\n    def __init__(self, in_dim):\n        super(SAM_Module, self).__init__()\n        self.channel_in = in_dim\n\n        # self.conv_du = nn.Sequential(\n        #     nn.Conv2d(2304*2304, 48, 1, padding=0, bias=True),\n        #     nn.ReLU(inplace=True),\n        #     nn.Conv2d(48, 2304*2304, 1, padding=0, bias=True),\n        # )\n        self.gamma = nn.Parameter(torch.zeros(1))\n        self.softmax  = nn.Softmax(dim=-1)\n        #self.sigmoid = nn.Sigmoid()\n        #self.conv0 = nn.Conv2d(in_dim*4, 1, 1, 1, 0)\n        self.pad1 = nn.ReplicationPad2d((0,0,1,0))\n        self.pad2 = nn.ReplicationPad2d((1,0,0,0))\n        self.pixel_shuffle = nn.PixelShuffle(2)\n\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X HW X HW\n        \"\"\"\n        x,top,left = self.depixel_shuffle(x)\n        m_batchsize, C, height, width = x.size()\n\n        proj_query = x.view(m_batchsize, -1, height*width)\n        proj_key = x.view(m_batchsize, -1, height*width).permute(0, 2, 1)\n        energy = torch.bmm(proj_key, proj_query)\n        energy1 = torch.zeros((m_batchsize, height*width, 1), device=x.device)\n        for i in range(height*width):\n            energy1.data[:,i] = torch.sqrt(energy[:,i,i]).unsqueeze(1)\n        energy = energy/energy1.expand_as(energy)\n        energy1 = energy1.permute(0, 2, 1)\n        #energy = energy/energy1.expand_as(energy)\n        energy = energy/energy1.expand_as(energy)\n
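        # Added note (not in the original code): after the two divisions above,\n        # 'energy' holds cosine similarities between every pair of spatial positions\n        # of the pixel-unshuffled map, so each softmax row below is a distribution\n        # over the positions most similar to position i.\n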
        #energy = self.conv_du(energy.view(m_batchsize, -1, 1, 1)).view(m_batchsize, H*W, H*W)\n        #energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n\n        energy = self.softmax(energy)\n        #attention = self.absmax(energy_new)\n        #proj_value = x.view(m_batchsize, -1, height*width)\n\n        out = torch.bmm(proj_query, energy.permute(0, 2, 1))\n        out = out.view(m_batchsize, -1, height, width)\n        #atten = self.conv0(x)\n\n\n        out = self.gamma*out + x\n        out = self.pixel_shuffle(out)\n        if top != 0:\n            out = out[:,:,1:,:]\n        if left != 0:\n            out = out[:,:,:,1:]\n        #out = out.view(m_batchsize, -1, height, width)\n        return out\n\n    def depixel_shuffle(self, x, upscale_factor=2):\n        batch_size, channels, height, width = x.size()\n        #pdb.set_trace()\n        out_channels = channels * (upscale_factor ** 2)\n        top,left = 0,0\n        if height%2==1:\n            x = self.pad1(x)\n            top=1\n        if width%2==1:\n            x = self.pad2(x)\n            left=1\n\n        height = math.ceil(height / upscale_factor)\n        width = math.ceil(width / upscale_factor)\n\n        x_view = x.contiguous().view(\n            batch_size, channels, height, upscale_factor, width, upscale_factor)\n\n        shuffle_out = x_view.permute(0, 1, 3, 5, 2, 4).contiguous()\n        return shuffle_out.view(batch_size, out_channels, height, width),top,left\n\n    def squaremax(self, x, dim=-1):\n        x_square = x.pow(2)\n        x_sum = torch.sum(x_square, dim=dim, keepdim=True)\n        s = x_square / x_sum\n        return s\n\n    def logmax(self,x):\n        x_log = torch.log(x+1)\n        x_sum = torch.sum(x_log, dim=-1, keepdim=True)\n        s = x_log / x_sum\n        return s\n\n    def absmax(self,x):\n        x_abs = torch.abs(x)\n        x_sum = torch.sum(x_abs, dim=-1, keepdim=True)\n        s = x_abs / x_sum\n        return s\n\nclass SECAM_Module(nn.Module):\n    \"\"\" SE channel attention module\"\"\"\n    def __init__(self, in_dim):\n        super(SECAM_Module, self).__init__()\n        self.channel_in = in_dim\n\n        # conv_du assumes in_dim == 64 channels (64 * 64 == 4096)\n        self.conv_du = nn.Sequential(\n            nn.Conv2d(4096, 16, 1, padding=0, bias=True),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(16, 4096, 1, padding=0, bias=True),\n        )\n        self.gamma = nn.Parameter(torch.zeros(1))\n        self.softmax  = nn.Softmax(dim=-1)\n\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X C X C\n        \"\"\"\n        m_batchsize, C, height, width = x.size()\n        #pdb.set_trace()\n\n        # proj_query = x.view(m_batchsize, C, -1)\n        # proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)\n        # energy = torch.bmm(proj_query, proj_key)\n        proj_query = x.contiguous().view(m_batchsize, C, -1)\n        proj_key = x.contiguous().view(m_batchsize, C, -1).permute(0, 2, 1)\n        energy = torch.bmm(proj_query, proj_key)\n        energy1 = torch.zeros((m_batchsize, C, 1), device=x.device)\n        for i in range(C):\n            energy1.data[:,i] = torch.sqrt(energy[:,i,i]).unsqueeze(1)\n        energy2 = energy1.permute(0, 2, 1)\n        energy = energy/energy1.expand_as(energy)\n        energy = energy/energy2.expand_as(energy)\n        energy = self.conv_du(energy.view(m_batchsize, -1, 1, 1)).view(m_batchsize, C, C)\n
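        # Added note (not in the original code): the row-max-minus-energy step below\n        # (the same trick as in DANet's channel attention) inverts the scores, so\n        # each softmax row puts its largest weights on the channels least similar\n        # to channel i.\n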
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n\n        attention = self.softmax(energy_new)\n        proj_value = x.view(m_batchsize, C, -1)\n\n        out = torch.bmm(attention, proj_value)\n        out = out.view(m_batchsize, -1, height, width)\n\n        out = self.gamma*out + x\n        #out = out.view(m_batchsize, -1, height, width)\n        return out\n\n\nclass LAM_Module(nn.Module):\n    \"\"\" Layer attention module\"\"\"\n    def __init__(self, in_dim):\n        super(LAM_Module, self).__init__()\n        self.channel_in = in_dim\n\n\n        self.dis = Dis('L1')\n        self.gamma = nn.Parameter(torch.zeros(1))\n        #self.energy1 =  torch.zeros((4, 11, 11)).cuda()\n        self.softmax  = nn.Softmax(dim=-1)\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X N X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X N X N\n            process:\n                reshape x to 2-D; for every pair of layer features, compute a\n                relation confidence, defined as the negated distance; this gives\n                a confidence matrix that weights the Gram matrix; apply the\n                learned scale factor and add the result back to the input.\n        \"\"\"\n        m_batchsize, N, C, height, width = x.size()\n        energy1 = torch.zeros((m_batchsize, N, N), device=x.device)\n\n        #energy2 = Variable(energy1,requires_grad=True)\n        proj_query = x.view(m_batchsize, N, -1)\n        proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1)\n        energy2 = torch.bmm(proj_query, proj_key)\n        for i in range(N):\n            #a = []\n            for j in range(i,N):\n                #pdb.set_trace()\n                #a.append(self.dis(proj_query[:][i],proj_query[:][j]))\n                energy1.data[:,i,j] = self.dis(proj_query[:,i],proj_query[:,j])\n                energy1.data[:,j,i] = energy1.data[:,i,j]\n            #energy1.append(a)\n\n        energy = energy1*energy2\n\n        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        attention = self.softmax(energy_new)\n        proj_value = x.view(m_batchsize, N, -1)\n\n        out = torch.bmm(attention, proj_value)\n        out = out.view(m_batchsize, N, C, height, width)\n\n        out = self.gamma*out + x\n        out = out.view(m_batchsize, -1, height, width)\n        return out\n\n\nclass GAM_Module(nn.Module):\n    \"\"\" Global attention module\"\"\"\n    def __init__(self, in_dim):\n        super(GAM_Module, self).__init__()\n        self.channel_in = in_dim\n\n\n        self.conv = nn.Conv3d(1, 1, 3, 1, 1)\n        self.gamma = nn.Parameter(torch.zeros(1))\n        #self.softmax  = nn.Softmax(dim=-1)\n        self.sigmoid = nn.Sigmoid()\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X C X H X W)\n            returns :\n                out : attention value + input feature\n        \"\"\"\n        m_batchsize, C, height, width = x.size()\n        out = x.unsqueeze(1)\n        out = self.sigmoid(self.conv(out))\n\n        # proj_query = x.view(m_batchsize, N, -1)\n        # proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1)\n        # energy = torch.bmm(proj_query, proj_key)\n        # energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        # attention = self.softmax(energy_new)\n        # proj_value = x.view(m_batchsize, N, -1)\n\n        # out = torch.bmm(attention, proj_value)\n        # out = out.view(m_batchsize, N, C, height, width)\n
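\n        # Added note (not in the original code): unlike the bmm-based modules\n        # above, GAM gates the features with a sigmoid-activated 3D conv over a\n        # singleton depth axis, i.e. a local mask rather than pairwise attention.\n        out = 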
self.gamma*out\n        out = out.view(m_batchsize, -1, height, width)\n        x = x * out + x\n        return x\n\n## Residual Channel Attention Block (RCAB)\nclass RCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction,\n        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(RCAB, self).__init__()\n        modules_body = []\n        for i in range(2):\n            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))\n            if bn: modules_body.append(nn.BatchNorm2d(n_feat))\n            if i == 0: modules_body.append(act)\n        modules_body.append(CALayer(n_feat, reduction))\n        self.body = nn.Sequential(*modules_body)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x)\n        #res = self.body(x).mul(self.res_scale)\n        res += x\n        return res\n\n## Residual Group (RG)\nclass ResidualGroup(nn.Module):\n    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):\n        super(ResidualGroup, self).__init__()\n        modules_body = []\n        modules_body = [\n            RCAB(\n                conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \\\n            for _ in range(n_resblocks)]\n        #modules_body.append(RCMSAN())\n        modules_body.append(conv(n_feat, n_feat, kernel_size))\n        self.body = nn.Sequential(*modules_body)\n\n    def forward(self, x):\n        res = self.body(x)\n        res += x\n        return res\n\n## Residual Channel Attention Network (RCAN)\nclass RCAN(nn.Module):\n    def __init__(self, args, conv=common.default_conv):\n        super(RCAN, self).__init__()\n        \n        n_resgroups = args.n_resgroups\n        n_resblocks = args.n_resblocks\n        n_feats = args.n_feats\n        kernel_size = 3\n        reduction = args.reduction \n        scale = args.scale[0]\n        act = nn.ReLU(True)\n        \n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n        \n        # define head module\n        modules_head = [conv(args.n_colors, n_feats, kernel_size)]\n\n        # define body module\n        modules_body = [\n            ResidualGroup(\n                conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \\\n            for _ in range(n_resgroups)]\n\n        modules_body.append(conv(n_feats, n_feats, kernel_size))\n\n        # define tail module\n        modules_tail = [\n            common.Upsampler(conv, scale, n_feats, act=False),\n            conv(n_feats, args.n_colors, kernel_size)]\n\n        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n        self.ca = SECAM_Module(n_feats)\n        self.sa = SAM_Module(n_feats)\n        self.da = SEDAM_Module(n_feats)\n        self.last_conv = nn.Conv2d(n_feats*11, n_feats, 3, 1, 1)\n        self.lastc = nn.Conv2d(n_feats*3, n_feats, 3, 1, 1)\n        self.tail = nn.Sequential(*modules_tail)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        x = self.head(x)\n        res = x\n        #pdb.set_trace()\n        for name, midlayer in self.body._modules.items():\n            res = midlayer(res)\n            #print(name)\n            if name=='0':\n                res1 = 
res.unsqueeze(1)\n            else:\n                res1 = torch.cat([res.unsqueeze(1),res1],1)\n        #res = self.body(x)\n        out1 = res\n        #res3 = res.unsqueeze(1)\n        #res = torch.cat([res1,res3],1)\n        res = self.da(res1)\n        out2 = self.last_conv(res)\n\n        out1 = self.sa(out1)\n        out3 = self.ca(out1)\n        out = torch.cat([out1, out2, out3], 1)\n        res = self.lastc(out)\n        \n        res += x\n        #res = self.ga(res)\n\n        x = self.tail(res)\n        x = self.add_mean(x)\n\n        return x \n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replace pre-trained upsampler to new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))"
  },
  {
    "path": "src/model/rcan4.py",
    "content": "from model import common\nimport torch\nimport torch.nn as nn\nimport pdb\n\ndef make_model(args, parent=False):\n    return RCAN(args)\n\n## Channel Attention (CA) Layer\nclass CALayer(nn.Module):\n    def __init__(self, channel, reduction=16):\n        super(CALayer, self).__init__()\n        # global average pooling: feature --> point\n        self.avg_pool = nn.AdaptiveAvgPool2d(1)\n        # feature channel downscale and upscale --> channel weight\n        self.conv_du = nn.Sequential(\n                nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),\n                nn.ReLU(inplace=True),\n                nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),\n                nn.Sigmoid()\n        )\n\n    def forward(self, x):\n        y = self.avg_pool(x)\n        y = self.conv_du(y)\n        return x * y\n\nclass DAM_Module(nn.Module):\n    \"\"\" Deep attention module\"\"\"\n    def __init__(self, in_dim):\n        super(DAM_Module, self).__init__()\n        self.chanel_in = in_dim\n\n\n        self.gamma = nn.Parameter(torch.zeros(1))\n        self.softmax  = nn.Softmax(dim=-1)\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X N X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X N X N\n        \"\"\"\n        m_batchsize, N, C, height, width = x.size()\n        proj_query = x.view(m_batchsize, N, -1)\n        proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1)\n        energy = torch.bmm(proj_query, proj_key)\n        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        attention = self.softmax(energy_new)\n        proj_value = x.view(m_batchsize, N, -1)\n\n        out = torch.bmm(attention, proj_value)\n        out = out.view(m_batchsize, N, C, height, width)\n\n        out = self.gamma*out + x\n        out = out.view(m_batchsize, -1, height, width)\n        return out\n\nclass GAM_Module(nn.Module):\n    \"\"\" Global\n    attention module\"\"\"\n    def __init__(self, in_dim):\n        super(GAM_Module, self).__init__()\n        self.chanel_in = in_dim\n\n\n        self.conv = nn.Conv3d(1, 1, 3, 1, 1)\n        self.gamma = nn.Parameter(torch.zeros(1))\n        #self.softmax  = nn.Softmax(dim=-1)\n        self.sigmoid = nn.Sigmoid()\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X N X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X N X N\n        \"\"\"\n        m_batchsize, C, height, width = x.size()\n        out = x.unsqueeze(1)\n        out = self.sigmoid(self.conv(out))\n        \n        # proj_query = x.view(m_batchsize, N, -1)\n        # proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1)\n        # energy = torch.bmm(proj_query, proj_key)\n        # energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        # attention = self.softmax(energy_new)\n        # proj_value = x.view(m_batchsize, N, -1)\n\n        # out = torch.bmm(attention, proj_value)\n        # out = out.view(m_batchsize, N, C, height, width)\n\n        out = self.gamma*out\n        out = out.view(m_batchsize, -1, height, width)\n        x = x * out + x\n        return x\n\n## Residual Channel Attention Block (RCAB)\nclass RCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction,\n        bias=True, 
\n## Residual Channel Attention Block (RCAB)\nclass RCAB(nn.Module):\n    def __init__(\n        self, conv, n_feat, kernel_size, reduction,\n        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):\n\n        super(RCAB, self).__init__()\n        modules_body = []\n        for i in range(2):\n            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))\n            if bn: modules_body.append(nn.BatchNorm2d(n_feat))\n            if i == 0: modules_body.append(act)\n        modules_body.append(CALayer(n_feat, reduction))\n        self.body = nn.Sequential(*modules_body)\n        self.res_scale = res_scale\n\n    def forward(self, x):\n        res = self.body(x)\n        #res = self.body(x).mul(self.res_scale)\n        res += x\n        return res\n\n## Residual Group (RG)\nclass ResidualGroup(nn.Module):\n    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):\n        super(ResidualGroup, self).__init__()\n        modules_body = []\n        modules_body = [\n            RCAB(\n                conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \\\n            for _ in range(n_resblocks)]\n        modules_body.append(conv(n_feat, n_feat, kernel_size))\n        self.body = nn.Sequential(*modules_body)\n\n    def forward(self, x):\n        res = self.body(x)\n        res += x\n        return res\n\n## Residual Channel Attention Network (RCAN)\nclass RCAN(nn.Module):\n    def __init__(self, args, conv=common.default_conv):\n        super(RCAN, self).__init__()\n\n        n_resgroups = args.n_resgroups\n        n_resblocks = args.n_resblocks\n        n_feats = args.n_feats\n        kernel_size = 3\n        reduction = args.reduction\n        scale = args.scale[0]\n        act = nn.ReLU(True)\n\n        # RGB mean for DIV2K\n        rgb_mean = (0.4488, 0.4371, 0.4040)\n        rgb_std = (1.0, 1.0, 1.0)\n        self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std)\n\n        # define head module\n        modules_head = [conv(args.n_colors, n_feats, kernel_size)]\n\n        # define body module\n        modules_body = [\n            ResidualGroup(\n                conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \\\n            for _ in range(n_resgroups)]\n\n        modules_body.append(conv(n_feats, n_feats, kernel_size))\n\n        # define tail module\n        modules_tail = [\n            common.Upsampler(conv, scale, n_feats, act=False),\n            conv(n_feats, args.n_colors, kernel_size)]\n\n        self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)\n\n        self.head = nn.Sequential(*modules_head)\n        self.body = nn.Sequential(*modules_body)\n        self.ga = GAM_Module(n_feats)\n        self.da = DAM_Module(n_feats)\n        self.last_conv1 = nn.Conv2d(n_feats*11, n_feats, 3, 1, 1)\n        self.ga_conv = nn.Conv2d(n_feats*11, n_feats, 3, 1, 1)\n        self.last = nn.Conv2d(n_feats*2, n_feats, 3, 1, 1)\n        self.tail = nn.Sequential(*modules_tail)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        x = self.head(x)\n        res = x\n        #pdb.set_trace()\n        for name, midlayer in self.body._modules.items():\n            res = midlayer(res)\n\n            #print(name)\n            if name=='0':\n                out1 = self.ga(res)\n                res1 = res.unsqueeze(1)\n            else:\n                out1 = 
torch.cat([self.ga(res),out1],1)\n                res1 = torch.cat([res.unsqueeze(1),res1],1)\n        #res = self.body(x)\n        out1 = self.ga_conv(out1)\n        #res3 = res.unsqueeze(1)\n        #res = torch.cat([res1,res3],1)\n        res = self.da(res1)\n        out2 = self.last_conv1(res)\n\n        #out1 = self.ga(out1)\n        out = torch.cat([out1, out2], 1)\n        res = self.last(out)\n        \n        res += x\n        #res = self.ga(res)\n\n        x = self.tail(res)\n        x = self.add_mean(x)\n\n        return x \n\n    def load_state_dict(self, state_dict, strict=False):\n        own_state = self.state_dict()\n        for name, param in state_dict.items():\n            if name in own_state:\n                if isinstance(param, nn.Parameter):\n                    param = param.data\n                try:\n                    own_state[name].copy_(param)\n                except Exception:\n                    if name.find('tail') >= 0:\n                        print('Replace pre-trained upsampler to new one...')\n                    else:\n                        raise RuntimeError('While copying the parameter named {}, '\n                                           'whose dimensions in the model are {} and '\n                                           'whose dimensions in the checkpoint are {}.'\n                                           .format(name, own_state[name].size(), param.size()))\n            elif strict:\n                if name.find('tail') == -1:\n                    raise KeyError('unexpected key \"{}\" in state_dict'\n                                   .format(name))\n\n        if strict:\n            missing = set(own_state.keys()) - set(state_dict.keys())\n            if len(missing) > 0:\n                raise KeyError('missing keys in state_dict: \"{}\"'.format(missing))"
  },
  {
    "path": "src/model/rdn.py",
    "content": "# Residual Dense Network for Image Super-Resolution\n# https://arxiv.org/abs/1802.08797\n\nfrom model import common\n\nimport torch\nimport torch.nn as nn\n\n\ndef make_model(args, parent=False):\n    return RDN(args)\n\nclass RDB_Conv(nn.Module):\n    def __init__(self, inChannels, growRate, kSize=3):\n        super(RDB_Conv, self).__init__()\n        Cin = inChannels\n        G  = growRate\n        self.conv = nn.Sequential(*[\n            nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1),\n            nn.ReLU()\n        ])\n\n    def forward(self, x):\n        out = self.conv(x)\n        return torch.cat((x, out), 1)\n\nclass RDB(nn.Module):\n    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):\n        super(RDB, self).__init__()\n        G0 = growRate0\n        G  = growRate\n        C  = nConvLayers\n        \n        convs = []\n        for c in range(C):\n            convs.append(RDB_Conv(G0 + c*G, G))\n        self.convs = nn.Sequential(*convs)\n        \n        # Local Feature Fusion\n        self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1)\n\n    def forward(self, x):\n        return self.LFF(self.convs(x)) + x\n\nclass RDN(nn.Module):\n    def __init__(self, args):\n        super(RDN, self).__init__()\n        r = args.scale[0]\n        G0 = args.G0\n        kSize = args.RDNkSize\n\n        # number of RDB blocks, conv layers, out channels\n        self.D, C, G = {\n            'A': (20, 6, 32),\n            'B': (16, 8, 64),\n        }[args.RDNconfig]\n\n        # Shallow feature extraction net\n        self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)\n        self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)\n\n        # Redidual dense blocks and dense feature fusion\n        self.RDBs = nn.ModuleList()\n        for i in range(self.D):\n            self.RDBs.append(\n                RDB(growRate0 = G0, growRate = G, nConvLayers = C)\n            )\n\n        # Global Feature Fusion\n        self.GFF = nn.Sequential(*[\n            nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),\n            nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)\n        ])\n\n        # Up-sampling net\n        if r == 2 or r == 3:\n            self.UPNet = nn.Sequential(*[\n                nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),\n                nn.PixelShuffle(r),\n                nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)\n            ])\n        elif r == 4:\n            self.UPNet = nn.Sequential(*[\n                nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),\n                nn.PixelShuffle(2),\n                nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),\n                nn.PixelShuffle(2),\n                nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)\n            ])\n        else:\n            raise ValueError(\"scale must be 2 or 3 or 4.\")\n\n    def forward(self, x):\n        f__1 = self.SFENet1(x)\n        x  = self.SFENet2(f__1)\n\n        RDBs_out = []\n        for i in range(self.D):\n            x = self.RDBs[i](x)\n            RDBs_out.append(x)\n\n        x = self.GFF(torch.cat(RDBs_out,1))\n        x += f__1\n\n        return self.UPNet(x)"
  },
  {
    "path": "src/model/rdn1.py",
    "content": "# Residual Dense Network for Image Super-Resolution\n# https://arxiv.org/abs/1802.08797\n\nfrom model import common\n\nimport torch\nimport torch.nn as nn\n\n\ndef make_model(args, parent=False):\n    return RDN(args)\n\nclass RDB_Conv(nn.Module):\n    def __init__(self, inChannels, growRate, kSize=(3,3,3)):\n        super(RDB_Conv, self).__init__()\n        Cin = inChannels\n        G  = growRate\n        self.conv = nn.Sequential(*[\n            nn.Conv3d(Cin, G, kSize, padding=(1,1,1), stride=1),\n            nn.ReLU()\n        ])\n        self.da = DAM_Module()\n\n    def forward(self, x):\n        x = self.da(x)\n        out = self.conv(x)\n        out = torch.cat((x, out), 1)   \n\n        return out\n\nclass DAM_Module(nn.Module):\n    \"\"\" Deep attention module\"\"\"\n    def __init__(self):\n        super(DAM_Module, self).__init__()\n\n\n        self.gamma = nn.Parameter(torch.zeros(1))\n        self.softmax  = nn.Softmax(dim=-1)\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X N X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X N X N\n        \"\"\"\n        m_batchsize, N, C, height, width = x.size()\n        proj_query = x.view(m_batchsize, N, -1)\n        proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1)\n        energy = torch.bmm(proj_query, proj_key)\n        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        attention = self.softmax(energy_new)\n        proj_value = x.view(m_batchsize, N, -1)\n\n        out = torch.bmm(attention, proj_value)\n        out = out.view(m_batchsize, N, C, height, width)\n\n        out = self.gamma*out + x\n        out = out.view(m_batchsize, -1, height, width)\n        return out\n\nclass RDB(nn.Module):\n    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):\n        super(RDB, self).__init__()\n        G0 = growRate0\n        G  = 1\n        C  = nConvLayers\n        \n        convs = []\n        for c in range(C):\n            convs.append(RDB_Conv(G0 + c*G, G))\n        self.convs = nn.Sequential(*convs)\n\n        self.da = DAM_Module()\n        \n        # Local Feature Fusion\n        self.LFF = nn.Conv3d(G0 + C*G, G0, 1, padding=0, stride=1)\n\n    def forward(self, x):\n        x = self.da(x)\n        out = self.LFF(self.convs(x)) + x\n        return out\n\nclass RDN(nn.Module):\n    def __init__(self, args):\n        super(RDN, self).__init__()\n        r = args.scale[0]\n        G0 = args.G0\n        kSize = args.RDNkSize\n\n        # number of RDB blocks, conv layers, out channels\n        self.D, C, G = {\n            'A': (20, 6, 32),\n            'B': (16, 8, 64),\n        }[args.RDNconfig]\n\n        # Shallow feature extraction net\n        self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)\n        self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)\n\n        # Redidual dense blocks and dense feature fusion\n        self.RDBs = nn.ModuleList()\n        for i in range(self.D):\n            self.RDBs.append(\n                RDB(growRate0 = G0, growRate = G, nConvLayers = C)\n            )\n        \n        self.da = DAM_Module()\n        # Global Feature Fusion\n        self.GFF = nn.Sequential(*[\n            nn.Conv3d(self.D * G0, G0, 1, padding=0, stride=1),\n            nn.Conv3d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)\n        ])\n\n        # Up-sampling net\n    
        if r == 2 or r == 3:\n            self.UPNet = nn.Sequential(*[\n                nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),\n                nn.PixelShuffle(r),\n                nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)\n            ])\n        elif r == 4:\n            self.UPNet = nn.Sequential(*[\n                nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),\n                nn.PixelShuffle(2),\n                nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),\n                nn.PixelShuffle(2),\n                nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)\n            ])\n        else:\n            raise ValueError(\"scale must be 2 or 3 or 4.\")\n\n    def forward(self, x):\n        f__1 = self.SFENet1(x)\n        x  = self.SFENet2(f__1).unsqueeze(1)\n\n        RDBs_out = []\n        for i in range(self.D):\n            x = self.RDBs[i](x)\n            RDBs_out.append(x)\n\n        x = torch.cat(RDBs_out,1)\n        B,N,C,H,W = x.size()\n        x = self.da(x)\n\n        x = self.GFF(x)\n        x = x.view(B,N*C,H,W)\n        x += f__1\n\n        return self.UPNet(x)\n"
  },
  {
    "path": "src/model/rdn2.py",
    "content": "# Residual Dense Network for Image Super-Resolution\n# https://arxiv.org/abs/1802.08797\n\nfrom model import common\n\nimport torch\nimport torch.nn as nn\n\n\ndef make_model(args, parent=False):\n    return RDN(args)\n\nclass RDB_Conv(nn.Module):\n    def __init__(self, inChannels, growRate, kSize=3):\n        super(RDB_Conv, self).__init__()\n        Cin = inChannels\n        G  = growRate\n        self.conv = nn.Sequential(*[\n            nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1),\n            nn.ReLU()\n        ])\n        self.da = DAM_Module()\n\n    def forward(self, x):\n        B,N,C,H,W = x.size()\n        x = self.da(x)\n        x = x.view(B,N*C,H,W)\n        out = self.conv(x).unsqueeze(1)\n        out = torch.cat((x, out), 1)   \n\n        return out\n\nclass DAM_Module(nn.Module):\n    \"\"\" Deep attention module\"\"\"\n    def __init__(self):\n        super(DAM_Module, self).__init__()\n\n\n        self.gamma = nn.Parameter(torch.zeros(1))\n        self.softmax  = nn.Softmax(dim=-1)\n    def forward(self,x):\n        \"\"\"\n            inputs :\n                x : input feature maps( B X N X C X H X W)\n            returns :\n                out : attention value + input feature\n                attention: B X N X N\n        \"\"\"\n        m_batchsize, N, C, height, width = x.size()\n        proj_query = x.view(m_batchsize, N, -1)\n        proj_key = x.view(m_batchsize, N, -1).permute(0, 2, 1)\n        energy = torch.bmm(proj_query, proj_key)\n        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy\n        attention = self.softmax(energy_new)\n        proj_value = x.view(m_batchsize, N, -1)\n\n        out = torch.bmm(attention, proj_value)\n        out = out.view(m_batchsize, N, C, height, width)\n\n        out = self.gamma*out + x\n        out = out.view(m_batchsize, -1, height, width)\n        return out\n\nclass RDB(nn.Module):\n    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):\n        super(RDB, self).__init__()\n        G0 = growRate0\n        G  = growRate\n        C  = nConvLayers\n        \n        convs = []\n        for c in range(C):\n            convs.append(RDB_Conv(G0 + c*G, G))\n        self.convs = nn.Sequential(*convs)\n\n        self.da = DAM_Module()\n        \n        # Local Feature Fusion\n        self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1)\n\n    def forward(self, x):\n        B,N,C,H,W = x.size()\n        out = self.da(x)\n        out = out.view(B,N*C,H,W)\n        out = self.LFF(self.convs(x)).unsqueeze(1) + x\n        return out\n\nclass RDN(nn.Module):\n    def __init__(self, args):\n        super(RDN, self).__init__()\n        r = args.scale[0]\n        G0 = args.G0\n        kSize = args.RDNkSize\n\n        # number of RDB blocks, conv layers, out channels\n        self.D, C, G = {\n            'A': (20, 6, 32),\n            'B': (16, 8, 64),\n        }[args.RDNconfig]\n\n        # Shallow feature extraction net\n        self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)\n        self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)\n\n        # Redidual dense blocks and dense feature fusion\n        self.RDBs = nn.ModuleList()\n        for i in range(self.D):\n            self.RDBs.append(\n                RDB(growRate0 = G0, growRate = G, nConvLayers = C)\n            )\n\n        self.da = DAM_Module()\n        # Global Feature Fusion\n        self.GFF = nn.Sequential(*[\n            
nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),\n            nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)\n        ])\n\n        # Up-sampling net\n        if r == 2 or r == 3:\n            self.UPNet = nn.Sequential(*[\n                nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),\n                nn.PixelShuffle(r),\n                nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)\n            ])\n        elif r == 4:\n            self.UPNet = nn.Sequential(*[\n                nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),\n                nn.PixelShuffle(2),\n                nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),\n                nn.PixelShuffle(2),\n                nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)\n            ])\n        else:\n            raise ValueError(\"scale must be 2 or 3 or 4.\")\n\n    def forward(self, x):\n        f__1 = self.SFENet1(x)\n        x  = self.SFENet2(f__1).unsqueeze(1)\n\n        RDBs_out = []\n        for i in range(self.D):\n            x = self.RDBs[i](x)\n            RDBs_out.append(x)\n\n        x = torch.cat(RDBs_out,1)\n        B,N,C,H,W = x.size()\n        x = self.da(x)\n        x = x.view(B,N*C,H,W)\n\n        x = self.GFF(x)\n        x += f__1\n\n        return self.UPNet(x)\n"
  },
  {
    "path": "src/model/vdsr.py",
    "content": "from model import common\n\nimport torch.nn as nn\nimport torch.nn.init as init\n\nurl = {\n    'r20f64': ''\n}\n\ndef make_model(args, parent=False):\n    return VDSR(args)\n\nclass VDSR(nn.Module):\n    def __init__(self, args, conv=common.default_conv):\n        super(VDSR, self).__init__()\n\n        n_resblocks = args.n_resblocks\n        n_feats = args.n_feats\n        kernel_size = 3 \n        self.url = url['r{}f{}'.format(n_resblocks, n_feats)]\n        self.sub_mean = common.MeanShift(args.rgb_range)\n        self.add_mean = common.MeanShift(args.rgb_range, sign=1)\n\n        def basic_block(in_channels, out_channels, act):\n            return common.BasicBlock(\n                conv, in_channels, out_channels, kernel_size,\n                bias=True, bn=False, act=act\n            )\n\n        # define body module\n        m_body = []\n        m_body.append(basic_block(args.n_colors, n_feats, nn.ReLU(True)))\n        for _ in range(n_resblocks - 2):\n            m_body.append(basic_block(n_feats, n_feats, nn.ReLU(True)))\n        m_body.append(basic_block(n_feats, args.n_colors, None))\n\n        self.body = nn.Sequential(*m_body)\n\n    def forward(self, x):\n        x = self.sub_mean(x)\n        res = self.body(x)\n        res += x\n        x = self.add_mean(res)\n\n        return x \n\n"
  },
  {
    "path": "src/option.py",
    "content": "import argparse\nimport template\n\nparser = argparse.ArgumentParser(description='EDSR and MDSR')\n\nparser.add_argument('--debug', action='store_true',\n                    help='Enables debug mode')\nparser.add_argument('--template', default='.',\n                    help='You can set various templates in option.py')\n\n# Hardware specifications\nparser.add_argument('--n_threads', type=int, default=1,\n                    help='number of threads for data loading')\nparser.add_argument('--cpu', action='store_true',\n                    help='use cpu only')\nparser.add_argument('--n_GPUs', type=int, default=2,\n                    help='number of GPUs')\nparser.add_argument('--seed', type=int, default=1,\n                    help='random seed')\n\n# Data specifications\nparser.add_argument('--dir_data', type=str, default='/media/zrh/备份/AIM/X4',\n                    help='dataset directory')\nparser.add_argument('--dir_demo', type=str, default='../test',\n                    help='demo image directory')\nparser.add_argument('--data_train', type=str, default='DIV2K',\n                    help='train dataset name')\nparser.add_argument('--data_test', type=str, default='Set20',\n                    help='test dataset name')\nparser.add_argument('--data_range', type=str, default='1-18000/18001-18999',\n                    help='train/test data range')\nparser.add_argument('--ext', type=str, default='sep',\n                    help='dataset file extension')\nparser.add_argument('--scale', type=str, default='4',\n                    help='super resolution scale')\nparser.add_argument('--patch_size', type=int, default=192,\n                    help='output patch size')\nparser.add_argument('--rgb_range', type=int, default=255,\n                    help='maximum value of RGB')\nparser.add_argument('--n_colors', type=int, default=3,\n                    help='number of color channels to use')\nparser.add_argument('--chop', action='store_true',\n                    help='enable memory-efficient forward')\nparser.add_argument('--no_augment', action='store_true',\n                    help='do not use data augmentation')\n\n# Model specifications\nparser.add_argument('--model', default='MatrixModel',\n                    help='model name')\n\nparser.add_argument('--act', type=str, default='relu',\n                    help='activation function')\nparser.add_argument('--pre_train', type=str, default='',\n                    help='pre-trained model directory')\nparser.add_argument('--extend', type=str, default='.',\n                    help='pre-trained model directory')\nparser.add_argument('--n_resblocks', type=int, default=20,\n                    help='number of residual blocks')\nparser.add_argument('--n_feats', type=int, default=128,\n                    help='number of feature maps')\nparser.add_argument('--block', type=str, default='BASIC',\n                    choices=('BASIC','EctBASIC','EctBOTTLENECK','CAEctBASIC'),\n                    help='type of residual blocks')\nparser.add_argument('--res_scale', type=float, default=1,\n                    help='residual scaling')\nparser.add_argument('--shift_mean', default=True,\n                    help='subtract pixel mean from the input')\nparser.add_argument('--dilation', action='store_true',\n                    help='use dilated convolution')\nparser.add_argument('--precision', type=str, default='single',\n                    choices=('single', 'half'),\n                    help='FP precision for test (single | half)')\n\n# Option 
for Residual dense network (RDN)\nparser.add_argument('--G0', type=int, default=64,\n                    help='default number of filters. (Use in RDN)')\nparser.add_argument('--RDNkSize', type=int, default=3,\n                    help='default kernel size. (Use in RDN)')\nparser.add_argument('--RDNconfig', type=str, default='B',\n                    help='parameters config of RDN. (Use in RDN)')\n\n# Option for Residual channel attention network (RCAN)\nparser.add_argument('--n_resgroups', type=int, default=10,\n                    help='number of residual groups')\nparser.add_argument('--reduction', type=int, default=16,\n                    help='channel reduction ratio in the attention layers')\n\n# Training specifications\nparser.add_argument('--reset', action='store_true',\n                    help='reset the training')\nparser.add_argument('--test_every', type=int, default=1000,\n                    help='run the test set every N training batches')\nparser.add_argument('--epochs', type=int, default=400,\n                    help='number of epochs to train')\nparser.add_argument('--batch_size', type=int, default=16,\n                    help='input batch size for training')\nparser.add_argument('--split_batch', type=int, default=1,\n                    help='split the batch into smaller chunks')\nparser.add_argument('--self_ensemble', action='store_true',\n                    help='use self-ensemble method for test')\nparser.add_argument('--test_only', action='store_true',\n                    help='set this option to test the model')\nparser.add_argument('--gan_k', type=int, default=1,\n                    help='k value for adversarial loss')\n\n# Optimization specifications\nparser.add_argument('--lr', type=float, default=1e-4,\n                    help='learning rate')\nparser.add_argument('--decay', type=str, default='200',\n                    help='learning rate decay type')\nparser.add_argument('--gamma', type=float, default=0.5,\n                    help='learning rate decay factor for step decay')\nparser.add_argument('--optimizer', default='ADAM',\n                    choices=('SGD', 'ADAM', 'RMSprop'),\n                    help='optimizer to use (SGD | ADAM | RMSprop)')\nparser.add_argument('--momentum', type=float, default=0.9,\n                    help='SGD momentum')\nparser.add_argument('--betas', type=tuple, default=(0.9, 0.999),\n                    help='ADAM beta')\nparser.add_argument('--epsilon', type=float, default=1e-8,\n                    help='ADAM epsilon for numerical stability')\nparser.add_argument('--weight_decay', type=float, default=0,\n                    help='weight decay')\nparser.add_argument('--gclip', type=float, default=0,\n                    help='gradient clipping threshold (0 = no clipping)')\n\n# Loss specifications\nparser.add_argument('--loss', type=str, default='1*MSE',\n                    help='loss function configuration')\nparser.add_argument('--skip_threshold', type=float, default=1e8,\n                    help='skip a batch whose loss exceeds this threshold')\n\n# Log specifications\nparser.add_argument('--save', type=str, default='test',\n                    help='file name to save')\nparser.add_argument('--load', type=str, default='',\n                    help='file name to load')\nparser.add_argument('--resume', type=int, default=0,\n                    help='resume from specific checkpoint')\nparser.add_argument('--save_models', action='store_true',\n                    help='save all intermediate models')\nparser.add_argument('--print_every', type=int, 
default=100,\n                    help='how many batches to wait before logging training status')\nparser.add_argument('--save_results', action='store_true',\n                    help='save output results')\nparser.add_argument('--save_gt', action='store_true',\n                    help='save low-resolution and high-resolution images together')\n\nargs = parser.parse_args()\ntemplate.set_template(args)\n\n# '+' separates multiple values, e.g. --scale 2+3+4 gives args.scale == [2, 3, 4]\nargs.scale = list(map(lambda x: int(x), args.scale.split('+')))\nargs.data_train = args.data_train.split('+')\nargs.data_test = args.data_test.split('+')\n\nif args.epochs == 0:\n    args.epochs = 1e8\n\n# options passed as the strings 'True'/'False' on the command line are coerced\n# to real booleans here\nfor arg in vars(args):\n    if vars(args)[arg] == 'True':\n        vars(args)[arg] = True\n    elif vars(args)[arg] == 'False':\n        vars(args)[arg] = False\n\n"
  },
  {
    "path": "src/template.py",
    "content": "def set_template(args):\n    # Set the templates here\n    if args.template.find('jpeg') >= 0:\n        args.data_train = 'DIV2K_jpeg'\n        args.data_test = 'DIV2K_jpeg'\n        args.epochs = 200\n        args.decay = '100'\n\n    if args.template.find('EDSR_paper') >= 0:\n        args.model = 'EDSR'\n        args.n_resblocks = 32\n        args.n_feats = 256\n        args.res_scale = 0.1\n\n    if args.template.find('MDSR') >= 0:\n        args.model = 'MDSR'\n        args.patch_size = 48\n        args.epochs = 650\n\n    if args.template.find('DDBPN') >= 0:\n        args.model = 'DDBPN'\n        args.patch_size = 128\n        args.scale = '4'\n\n        args.data_test = 'Set5'\n\n        args.batch_size = 20\n        args.epochs = 1000\n        args.decay = '500'\n        args.gamma = 0.1\n        args.weight_decay = 1e-4\n\n        args.loss = '1*MSE'\n\n    if args.template.find('GAN') >= 0:\n        args.epochs = 200\n        args.lr = 5e-5\n        args.decay = '150'\n\n\n    if args.template.find('RCAN') >= 0:\n        args.model = 'RCAN'\n        args.n_resgroups = 10\n        args.n_resblocks = 20\n        args.n_feats = 64\n        args.chop = True\n\n    if args.template.find('HAN') >= 0:\n        args.model = 'HAN'\n        args.n_resgroups = 10\n        args.n_resblocks = 20\n        args.n_feats = 64\n        args.chop = True\n\n    if args.template.find('RCAN2') >= 0:\n        args.model = 'RCAN2'\n        args.n_resgroups = 10\n        args.n_resblocks = 20\n        args.n_feats = 64\n        args.chop = True\n\n    if args.template.find('RCAN3') >= 0:\n        args.model = 'RCAN3'\n        args.n_resgroups = 10\n        args.n_resblocks = 20\n        args.n_feats = 64\n        args.chop = True\n\n    if args.template.find('RCAN4') >= 0:\n        args.model = 'RCAN4'\n        args.n_resgroups = 10\n        args.n_resblocks = 20\n        args.n_feats = 64\n        args.chop = True\n\n    if args.template.find('VDSR') >= 0:\n        args.model = 'VDSR'\n        args.n_resblocks = 20\n        args.n_feats = 64\n        args.patch_size = 41\n        args.lr = 1e-1\n\n"
  },
  {
    "path": "src/trainer.py",
    "content": "import os\nimport math\nfrom decimal import Decimal\n\nimport utility\n\nimport torch\nimport torch.nn.utils as utils\nfrom tqdm import tqdm\nimport pdb\n\nclass Trainer():\n    def __init__(self, args, loader, my_model, my_loss, ckp):\n        self.args = args\n        self.scale = args.scale\n\n        self.ckp = ckp\n        self.loader_train = loader.loader_train\n        self.loader_test = loader.loader_test\n        self.model = my_model\n        self.loss = my_loss\n        self.optimizer = utility.make_optimizer(args, self.model)\n\n        if self.args.load != '':\n            self.optimizer.load(ckp.dir, epoch=len(ckp.log))\n            #print('aaaaaaaaaaaaaaaaaaaaaaaaaaaaa')\n\n        self.error_last = 1e8\n\n    def train(self):\n        self.loss.step()\n        #print(self.optimizer.get_last_epoch())\n        epoch = self.optimizer.get_last_epoch() + 1\n        #pdb.set_trace\n        lr = self.optimizer.get_lr()\n\n        self.ckp.write_log(\n            '[Epoch {}]\\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))\n        )\n        self.loss.start_log()\n        self.model.train()\n\n        timer_data, timer_model = utility.timer(), utility.timer()\n        # TEMP\n        self.loader_train.dataset.set_scale(0)\n        for batch, (lr, hr, _,) in enumerate(self.loader_train):\n            lr, hr = self.prepare(lr, hr)\n            timer_data.hold()\n            timer_model.tic()\n\n            self.optimizer.zero_grad()\n            sr = self.model(lr, 0)\n            loss = self.loss(sr, hr)\n            loss.backward()\n            if self.args.gclip > 0:\n                utils.clip_grad_value_(\n                    self.model.parameters(),\n                    self.args.gclip\n                )\n            self.optimizer.step()\n\n            timer_model.hold()\n\n            if (batch + 1) % self.args.print_every == 0:\n                self.ckp.write_log('[{}/{}]\\t{}\\t{:.1f}+{:.1f}s'.format(\n                    (batch + 1) * self.args.batch_size,\n                    len(self.loader_train.dataset),\n                    self.loss.display_loss(batch),\n                    timer_model.release(),\n                    timer_data.release()))\n\n            timer_data.tic()\n\n        self.loss.end_log(len(self.loader_train))\n        self.error_last = self.loss.log[-1, -1]\n        self.optimizer.schedule()\n\n    def test(self):\n        torch.set_grad_enabled(False)\n\n        epoch = self.optimizer.get_last_epoch()\n        #print(epoch)\n        self.ckp.write_log('\\nEvaluation:')\n        self.ckp.add_log(\n            torch.zeros(1, len(self.loader_test), len(self.scale))\n        )\n        self.model.eval()\n\n        timer_test = utility.timer()\n        if self.args.save_results: self.ckp.begin_background()\n        for idx_data, d in enumerate(self.loader_test):\n            for idx_scale, scale in enumerate(self.scale):\n                d.dataset.set_scale(idx_scale)\n                for lr, hr, filename in tqdm(d, ncols=80):\n                    lr, hr = self.prepare(lr, hr)\n                    sr = self.model(lr, idx_scale)\n                    sr = utility.quantize(sr, self.args.rgb_range)\n\n                    save_list = [sr]\n                    self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(\n                        sr, hr, scale, self.args.rgb_range, dataset=d\n                    )\n                    if self.args.save_gt:\n                        save_list.extend([lr, hr])\n\n                    if 
                    if self.args.save_results:\n                        self.ckp.save_results(d, filename[0], save_list, scale)\n\n                self.ckp.log[-1, idx_data, idx_scale] /= len(d)\n                best = self.ckp.log.max(0)\n                self.ckp.write_log(\n                    '[{} x{}]\\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(\n                        d.dataset.name,\n                        scale,\n                        self.ckp.log[-1, idx_data, idx_scale],\n                        best[0][idx_data, idx_scale],\n                        best[1][idx_data, idx_scale]\n                    )\n                )\n\n        self.ckp.write_log('Forward: {:.2f}s\\n'.format(timer_test.toc()))\n        self.ckp.write_log('Saving...')\n\n        if self.args.save_results:\n            self.ckp.end_background()\n\n        if not self.args.test_only:\n            self.ckp.save(self, epoch, is_best=(best[1][0, 0] == epoch))\n\n        self.ckp.write_log(\n            'Total: {:.2f}s\\n'.format(timer_test.toc()), refresh=True\n        )\n\n        torch.set_grad_enabled(True)\n\n    def prepare(self, *args):\n        device = torch.device('cpu' if self.args.cpu else 'cuda')\n        def _prepare(tensor):\n            if self.args.precision == 'half': tensor = tensor.half()\n            return tensor.to(device)\n\n        return [_prepare(a) for a in args]\n\n    def terminate(self):\n        if self.args.test_only:\n            self.test()\n            return True\n        else:\n            epoch = self.optimizer.get_last_epoch() + 1\n            return epoch >= self.args.epochs\n\n"
  },
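  {
    "path": "src/examples/gclip_sketch.py",
    "content": "'''Illustrative sketch, not part of the original sources: what the\nargs.gclip branch in Trainer.train() (src/trainer.py) does. The file name,\nthe parameter, and the gradient values are made up for demonstration.'''\nimport torch\nimport torch.nn.utils as utils\n\n# A single parameter with a hand-written gradient stands in for a model.\nw = torch.nn.Parameter(torch.zeros(3))\nw.grad = torch.tensor([-2.0, 0.5, 7.0])\n\n# With args.gclip > 0, Trainer clamps every gradient entry element-wise into\n# [-gclip, gclip] before optimizer.step(); here gclip is assumed to be 1.0.\nutils.clip_grad_value_([w], clip_value=1.0)\n\nprint(w.grad)  # tensor([-1.0000,  0.5000,  1.0000])\n"
  },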
  {
    "path": "src/utility.py",
    "content": "import os\nimport math\nimport time\nimport datetime\nfrom multiprocessing import Process\nfrom multiprocessing import Queue\n\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\n\nimport numpy as np\nimport imageio\n\nimport torch\nimport torch.optim as optim\nimport torch.optim.lr_scheduler as lrs\n\nclass timer():\n    def __init__(self):\n        self.acc = 0\n        self.tic()\n\n    def tic(self):\n        self.t0 = time.time()\n\n    def toc(self, restart=False):\n        diff = time.time() - self.t0\n        if restart: self.t0 = time.time()\n        return diff\n\n    def hold(self):\n        self.acc += self.toc()\n\n    def release(self):\n        ret = self.acc\n        self.acc = 0\n\n        return ret\n\n    def reset(self):\n        self.acc = 0\n\nclass checkpoint():\n    def __init__(self, args):\n        self.args = args\n        self.ok = True\n        self.log = torch.Tensor()\n        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')\n\n        if not args.load:\n            if not args.save:\n                args.save = now\n            self.dir = os.path.join('..', 'experiment', args.save)\n        else:\n            self.dir = os.path.join('..', 'experiment', args.load)\n            if os.path.exists(self.dir):\n                self.log = torch.load(self.get_path('psnr_log.pt'))\n                print('Continue from epoch {}...'.format(len(self.log)))\n            else:\n                args.load = ''\n\n        if args.reset:\n            os.system('rm -rf ' + self.dir)\n            args.load = ''\n\n        os.makedirs(self.dir, exist_ok=True)\n        os.makedirs(self.get_path('model'), exist_ok=True)\n        for d in args.data_test:\n            os.makedirs(self.get_path('results-{}'.format(d)), exist_ok=True)\n\n        open_type = 'a' if os.path.exists(self.get_path('log.txt'))else 'w'\n        self.log_file = open(self.get_path('log.txt'), open_type)\n        with open(self.get_path('config.txt'), open_type) as f:\n            f.write(now + '\\n\\n')\n            for arg in vars(args):\n                f.write('{}: {}\\n'.format(arg, getattr(args, arg)))\n            f.write('\\n')\n\n        self.n_processes = 8\n\n    def get_path(self, *subdir):\n        return os.path.join(self.dir, *subdir)\n\n    def save(self, trainer, epoch, is_best=False):\n        trainer.model.save(self.get_path('model'), epoch, is_best=is_best)\n        trainer.loss.save(self.dir)\n        #trainer.loss.plot_loss(self.dir, epoch)\n\n        #self.plot_psnr(epoch)\n        trainer.optimizer.save(self.dir)\n        torch.save(self.log, self.get_path('psnr_log.pt'))\n\n    def add_log(self, log):\n        self.log = torch.cat([self.log, log])\n\n    def write_log(self, log, refresh=False):\n        print(log)\n        self.log_file.write(log + '\\n')\n        if refresh:\n            self.log_file.close()\n            self.log_file = open(self.get_path('log.txt'), 'a')\n\n    def done(self):\n        self.log_file.close()\n\n    def plot_psnr(self, epoch):\n        axis = np.linspace(1, epoch, epoch)\n        for idx_data, d in enumerate(self.args.data_test):\n            label = 'SR on {}'.format(d)\n            fig = plt.figure()\n            plt.title(label)\n            for idx_scale, scale in enumerate(self.args.scale):\n                plt.plot(\n                    axis,\n                    self.log[:, idx_data, idx_scale].numpy(),\n                    label='Scale {}'.format(scale)\n                )\n            
plt.legend()\n            plt.xlabel('Epochs')\n            plt.ylabel('PSNR')\n            plt.grid(True)\n            plt.savefig(self.get_path('test_{}.pdf'.format(d)))\n            plt.close(fig)\n\n    def begin_background(self):\n        self.queue = Queue()\n\n        def bg_target(queue):\n            while True:\n                if not queue.empty():\n                    filename, tensor = queue.get()\n                    if filename is None: break\n                    imageio.imwrite(filename, tensor.numpy())\n        \n        self.process = [\n            Process(target=bg_target, args=(self.queue,)) \\\n            for _ in range(self.n_processes)\n        ]\n        \n        for p in self.process: p.start()\n\n    def end_background(self):\n        for _ in range(self.n_processes): self.queue.put((None, None))\n        while not self.queue.empty(): time.sleep(1)\n        for p in self.process: p.join()\n\n    def save_results(self, dataset, filename, save_list, scale):\n        if self.args.save_results:\n            filename = self.get_path(\n                'results-{}'.format(dataset.dataset.name),\n                '{}_x{}_'.format(filename, scale)\n            )\n\n            postfix = ('SR', 'LR', 'HR')\n            for v, p in zip(save_list, postfix):\n                normalized = v[0].mul(255 / self.args.rgb_range)\n                tensor_cpu = normalized.byte().permute(1, 2, 0).cpu()\n                self.queue.put(('{}{}.png'.format(filename, p), tensor_cpu))\n\ndef quantize(img, rgb_range):\n    pixel_range = 255 / rgb_range\n    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)\n\ndef calc_psnr(sr, hr, scale, rgb_range, dataset=None):\n    if hr.nelement() == 1: return 0\n\n    diff = (sr - hr) / rgb_range\n    if dataset and dataset.dataset.benchmark:\n        shave = scale\n        if diff.size(1) > 1:\n            gray_coeffs = [65.738, 129.057, 25.064]\n            convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256\n            diff = diff.mul(convert).sum(dim=1)\n    else:\n        shave = scale + 6\n\n    valid = diff[..., shave:-shave, shave:-shave]\n    mse = valid.pow(2).mean()\n\n    return -10 * math.log10(mse)\n\ndef make_optimizer(args, target):\n    '''\n        make optimizer and scheduler together\n    '''\n    # optimizer\n    trainable = filter(lambda x: x.requires_grad, target.parameters())\n    kwargs_optimizer = {'lr': args.lr, 'weight_decay': args.weight_decay}\n\n    if args.optimizer == 'SGD':\n        optimizer_class = optim.SGD\n        kwargs_optimizer['momentum'] = args.momentum\n    elif args.optimizer == 'ADAM':\n        optimizer_class = optim.Adam\n        kwargs_optimizer['betas'] = args.betas\n        kwargs_optimizer['eps'] = args.epsilon\n    elif args.optimizer == 'RMSprop':\n        optimizer_class = optim.RMSprop\n        kwargs_optimizer['eps'] = args.epsilon\n\n    # scheduler\n    milestones = list(map(lambda x: int(x), args.decay.split('-')))\n    kwargs_scheduler = {'milestones': milestones, 'gamma': args.gamma}\n    scheduler_class = lrs.MultiStepLR\n\n    class CustomOptimizer(optimizer_class):\n        def __init__(self, *args, **kwargs):\n            super(CustomOptimizer, self).__init__(*args, **kwargs)\n\n        def _register_scheduler(self, scheduler_class, **kwargs):\n            self.scheduler = scheduler_class(self, **kwargs)\n\n        def save(self, save_dir):\n            torch.save(self.state_dict(), self.get_dir(save_dir))\n\n        def load(self, load_dir, epoch=1):\n        
    self.load_state_dict(torch.load(self.get_dir(load_dir)))\n            if epoch > 1:\n                for _ in range(epoch): self.scheduler.step()\n\n        def get_dir(self, dir_path):\n            return os.path.join(dir_path, 'optimizer.pt')\n\n        def schedule(self):\n            self.scheduler.step()\n\n        def get_lr(self):\n            return self.scheduler.get_lr()[0]\n\n        def get_last_epoch(self):\n            return self.scheduler.last_epoch\n    \n    optimizer = CustomOptimizer(trainable, **kwargs_optimizer)\n    optimizer._register_scheduler(scheduler_class, **kwargs_scheduler)\n    return optimizer\n\n"
  },
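  {
    "path": "src/examples/psnr_sketch.py",
    "content": "'''Illustrative sketch, not part of the original sources: how the PSNR\nreported by calc_psnr in src/utility.py relates to the textbook definition.\nThe file name and the synthetic tensors are assumptions; the shave width and\nthe -10 * log10(MSE) formula mirror the non-benchmark path of calc_psnr.'''\nimport math\n\nimport torch\n\nrgb_range = 255\nscale = 2\n\n# Synthetic NCHW ground truth and a noisy reconstruction in [0, rgb_range].\nhr = torch.rand(1, 3, 64, 64) * rgb_range\nsr = (hr + torch.randn_like(hr)).clamp(0, rgb_range)\n\n# calc_psnr normalizes the difference by rgb_range, shaves a border of\n# scale + 6 pixels on non-benchmark data, and takes the MSE of the rest.\nshave = scale + 6\ndiff = (sr - hr) / rgb_range\nvalid = diff[..., shave:-shave, shave:-shave]\nmse = valid.pow(2).mean().item()\n\n# With the signal normalized to peak 1, PSNR = -10 * log10(MSE).\nprint('PSNR: {:.3f} dB'.format(-10 * math.log10(mse)))\n"
  },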
  {
    "path": "src/videotester.py",
    "content": "import os\nimport math\n\nimport utility\nfrom data import common\n\nimport torch\nimport cv2\n\nfrom tqdm import tqdm\n\nclass VideoTester():\n    def __init__(self, args, my_model, ckp):\n        self.args = args\n        self.scale = args.scale\n\n        self.ckp = ckp\n        self.model = my_model\n\n        self.filename, _ = os.path.splitext(os.path.basename(args.dir_demo))\n\n    def test(self):\n        torch.set_grad_enabled(False)\n\n        self.ckp.write_log('\\nEvaluation on video:')\n        self.model.eval()\n\n        timer_test = utility.timer()\n        for idx_scale, scale in enumerate(self.scale):\n            vidcap = cv2.VideoCapture(self.args.dir_demo)\n            total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))\n            vidwri = cv2.VideoWriter(\n                self.ckp.get_path('{}_x{}.avi'.format(self.filename, scale)),\n                cv2.VideoWriter_fourcc(*'XVID'),\n                vidcap.get(cv2.CAP_PROP_FPS),\n                (\n                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_WIDTH)),\n                    int(scale * vidcap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n                )\n            )\n\n            tqdm_test = tqdm(range(total_frames), ncols=80)\n            for _ in tqdm_test:\n                success, lr = vidcap.read()\n                if not success: break\n\n                lr, = common.set_channel(lr, n_channels=self.args.n_colors)\n                lr, = common.np2Tensor(lr, rgb_range=self.args.rgb_range)\n                lr, = self.prepare(lr.unsqueeze(0))\n                sr = self.model(lr, idx_scale)\n                sr = utility.quantize(sr, self.args.rgb_range).squeeze(0)\n\n                normalized = sr * 255 / self.args.rgb_range\n                ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()\n                vidwri.write(ndarr)\n\n            vidcap.release()\n            vidwri.release()\n\n        self.ckp.write_log(\n            'Total: {:.2f}s\\n'.format(timer_test.toc()), refresh=True\n        )\n        torch.set_grad_enabled(True)\n\n    def prepare(self, *args):\n        device = torch.device('cpu' if self.args.cpu else 'cuda')\n        def _prepare(tensor):\n            if self.args.precision == 'half': tensor = tensor.half()\n            return tensor.to(device)\n\n        return [_prepare(a) for a in args]\n\n"
  }
]